Skip to content
Snippets Groups Projects
test_roiset.py 23.1 KiB
Newer Older
import re
import unittest
from pathlib import Path

import numpy as np
import pandas as pd

from model_server.conf.testing import output_path, roiset_test_data
from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
from model_server.base.roiset import RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
from model_server.base.models import DummyInstanceSegmentationModel


class BaseTestRoiSetMonoProducts(object):
    """Shared fixture: a multichannel z-stack, one mono channel of it, and its segmentation mask."""

    def setUp(self) -> None:
        # Look up file locations and the patch channel in the test-data registry, then load them.
        zstack_info = roiset_test_data['multichannel_zstack']
        patches_channel = roiset_test_data['pipeline_params']['patches_channel']
        self.stack = generate_file_accessor(zstack_info['path'])
        # single channel of the stack, used as the raw image for mono RoiSet tests
        self.stack_ch_pa = self.stack.get_mono(patches_channel)
        self.seg_mask = generate_file_accessor(zstack_info['mask_path'])

class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    def _make_roi_set(self, mask_type='boxes', **kwargs):
        # Helper used by the tests to build the RoiSet under test with the given mask_type and an
        # area filter (kwargs['filters'] overrides the default of areas between 1e3 and 1e4).
        # NOTE(review): the statement these keyword arguments belong to (and its return) is not
        # visible here; lines appear to be missing from this view. Confirm against upstream.
                mask_type=mask_type,
                filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
    def test_roi_mask_shape(self, **kwargs):
        """Each ROI's binary mask is boolean, shaped (h, w), and its box covers at least the ROI's area."""
        df = self._make_roi_set(**kwargs).get_df()

        for row in df.itertuples():
            mask = row.binary_mask
            self.assertEqual(mask.dtype, 'bool')
            shape = mask.shape
            self.assertEqual(shape, (row.h, row.w))
            # the bounding box can never hold fewer pixels than the ROI itself
            self.assertGreaterEqual(shape[0] * shape[1], row.area)

    def test_roi_zmask(self, **kwargs):
        """The RoiSet's z-mask is a single-channel mask with the raw stack's depth, neither empty nor full."""
        roiset = self._make_roi_set(**kwargs)
        # zmask was referenced below without ever being assigned; obtain it from the RoiSet
        zmask = roiset.get_zmask()
        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

        # assert dimensionality of zmask
        self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz)
        self.assertEqual(zmask_acc.chroma, 1)
        write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)

        # mask values are not just all True or all False
        self.assertTrue(np.any(zmask))
        self.assertFalse(np.all(zmask))
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
    def test_roiset_from_non_zstacks(self, **kwargs):
        """A RoiSet built from an accessor with nz == 1 still produces a valid mask from get_zmask()."""
        single_slice = self.stack_ch_pa.data[:, :, :, 0]
        acc_zstack_slice = InMemoryDataAccessor(single_slice)
        self.assertEqual(acc_zstack_slice.nz, 1)

        params = RoiSetMetaParams(mask_type='boxes')
        roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=params)
        self.assertTrue(InMemoryDataAccessor(roiset.get_zmask()).is_mask())

    def test_create_roiset_with_no_objects(self):
        """A RoiSet built from an all-zero object map is constructible and contains no ROIs."""
        zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype))
        roiset = RoiSet(self.stack_ch_pa, zero_obmap)
        # previously only construction was exercised; also pin the (empty) ROI count
        self.assertEqual(roiset.count, 0)
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
    def test_slices_are_valid(self):
        # Each slice `s` should select a non-empty 4D block of the raw data.
        # NOTE(review): `roiset` and the loop that binds `s` are not visible in this view; the lines
        # defining them appear to be missing. Confirm against upstream.
            ebb = roiset.acc_raw.data[s]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))

    def test_dataframe_and_mask_array_in_iterator(self):
        # Iterate the RoiSet directly; the mask array for each item is expected to be (roi.h, roi.w).
        roiset = self._make_roi_set()
        for roi in roiset:
            # NOTE(review): `ma` is never assigned in the visible code; the line that unpacks or
            # derives the mask array appears to be missing. Confirm against upstream.
            self.assertEqual(ma.shape, (roi.h, roi.w))
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
    def test_rel_slices_are_valid(self):
        # For each ROI, the expanded slice must select a non-empty 4D block of raw data, and the
        # relative slice applied to that block must also select a non-empty 4D block.
        # NOTE(review): `roiset` and the loop binding `roi` are not visible here; those lines appear
        # to be missing. Confirm against upstream.
            ebb = roiset.acc_raw.data[roi.expanded_slice]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
            rbb = ebb[roi.relative_slice]
            self.assertEqual(len(rbb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in rbb.shape]))

    def test_make_expanded_2d_patches(self):
        # Each exported expanded 2D patch should map to exactly one ROI label and be 256 x 256.
        roiset = self._make_roi_set()
        where = output_path / 'expanded_2d_patches'
        df = roiset.get_df()
        # NOTE(review): the export call and the `for f in ...` loop header are not visible here;
        # those lines appear to be missing. Confirm against upstream.
            acc = generate_file_accessor(where / f)
            # recover the ROI label from the 'la<N>' token in the file name
            la = int(re.search(r'la([\d]+)', str(f)).group(1))
            roi_q = df.loc[df.label == la, :]
            self.assertEqual(len(roi_q), 1)
            self.assertEqual((256, 256), acc.hw)

    def test_make_tight_2d_patches(self):
        # Tight (non-expanded) 2D patches with bounding boxes drawn should match each ROI's (h, w).
        roiset = self._make_roi_set()
        where = output_path / 'tight_2d_patches'
        # NOTE(review): the opening of the export call that produces `df_res` is not visible; the
        # keyword arguments below are its tail. Confirm against upstream.
            draw_bounding_box=True,
            expanded=False
        )
        df = roiset.get_df()
        for f in df_res.patch_path:  # all exported files are same shape as bounding boxes in RoiSet's datatable
            acc = generate_file_accessor(where / f)
            # recover the ROI label from the 'la<N>' token in the file name
            la = int(re.search(r'la([\d]+)', str(f)).group(1))
            roi_q = df.loc[df.label == la, :]
            self.assertEqual(len(roi_q), 1)
            roi = roi_q.iloc[0]
            self.assertEqual((roi.h, roi.w), acc.hw)

    def test_make_expanded_3d_patches(self):
        # Expanded 3D patch export should write at least one file, each with more than one z-slice.
        roiset = self._make_roi_set()
        where = output_path / '3d_patches'
        # NOTE(review): the opening of the export call that produces `df_res` is not visible; the
        # keyword arguments below are its tail. Confirm against upstream.
            make_3d=True,
            expanded=True
        )
        self.assertGreaterEqual(len(df_res), 1)
        for f in df_res.patch_path:
            acc = generate_file_accessor(where / f)
            self.assertGreater(acc.nz, 1)

    def test_export_annotated_zstack(self):
        # The exported annotated z-stack should have the same shape as the raw data.
        roiset = self._make_roi_set()
        where = output_path / 'annotated_zstack'
        # NOTE(review): this call's arguments and closing parenthesis are not visible; lines appear
        # to be missing from this view.
        file = roiset.export_annotated_zstack(
        result = generate_file_accessor(where / file)
        self.assertEqual(result.shape, roiset.acc_raw.shape)

        roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
        from model_server.base.roiset import project_stack_from_focal_points
            df['centroid-0'].to_numpy(),
            df['centroid-1'].to_numpy(),
            df['zi'].to_numpy(),
        )

        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])

        write_accessor_data_to_file(
            output_path / 'flattened.tif',
            InMemoryDataAccessor(img)
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
    def test_make_binary_masks(self):
        """Each exported 2D mask patch maps to exactly one ROI and matches that ROI's (h, w)."""
        roiset = self._make_roi_set()
        where = output_path / '2d_mask_patches'
        df_res = roiset.export_patch_masks(where)
        # df was referenced below without being assigned; compare against the RoiSet's own table
        df = roiset.get_df()
        for f in df_res.patch_mask_path:  # all exported files are same shape as bounding boxes in RoiSet's datatable
            acc = generate_file_accessor(where / f)
            # recover the ROI label from the 'la<N>' token in the file name
            la = int(re.search(r'la([\d]+)', str(f)).group(1))
            roi_q = df.loc[df.label == la, :]
            self.assertEqual(len(roi_q), 1)
            roi = roi_q.iloc[0]
            self.assertEqual((roi.h, roi.w), acc.hw)
    def test_classify_by(self):
        """Classify ROIs with the dummy model on channel 0; every ROI should get class 1.

        Returns the classified RoiSet so other tests (test_export_object_classes) can reuse it.
        """
        roiset = self._make_roi_set()
        roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
        # Compare as lists: `all(arr == [1])` passes vacuously when `arr` is empty.
        self.assertEqual(list(roiset.get_df()['classify_by_dummy_class'].unique()), [1])
        self.assertEqual(list(np.unique(roiset.object_class_maps['dummy_class'].data)), [0, 1])
        return roiset

    def test_classify_by_multiple_channels(self):
        """Classify ROIs of the full multichannel stack with the dummy model on channels 0 and 1.

        Every ROI should get class 1; the classified RoiSet is returned.
        """
        roiset = RoiSet.from_segmentation(self.stack, self.seg_mask)
        roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
        # Compare as lists: `all(arr == [1])` passes vacuously when `arr` is empty.
        self.assertEqual(list(roiset.get_df()['classify_by_dummy_class'].unique()), [1])
        self.assertEqual(list(np.unique(roiset.object_class_maps['dummy_class'].data)), [0, 1])
        return roiset

    def test_classify_by_with_derived_channel(self):
        # Classify with a model whose output is scaled by its input's channel count, supplying two
        # derived channels computed by lambdas; then check class values, derived-channel geometry,
        # and that derived channels can be exported.
        class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
            def infer(self, img, mask):
                # scale the dummy model's output by the number of input channels
                return PatchStack(super().infer(img, mask).data * img.chroma)

        roiset = RoiSet.from_segmentation(
            self.stack,
            self.seg_mask,
            params=RoiSetMetaParams(
                filters={'area': {'min': 1e3, 'max': 1e4}},
            )
        )
        # NOTE(review): the call these derived-channel lambdas belong to (presumably
        # roiset.classify_by('multiple_input_model', ...)) is only partially visible; its opening
        # and closing lines appear to be missing. Confirm against upstream.
                lambda acc: PatchStack(2 * acc.get_channels([0]).data),
                lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8'))
        self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4]))
        self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 4]))

        # each derived channel must have the same geometry and count as the raw patch stack
        self.assertEqual(len(roiset.accs_derived), 2)
        for di in roiset.accs_derived:
            self.assertEqual(roiset.get_patches_acc().hw, di.hw)
            self.assertEqual(roiset.get_patches_acc().nz, di.nz)
            self.assertEqual(roiset.get_patches_acc().count, di.count)

        dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True))
        for fp in dpas['derived_channels']:
            assert Path(fp).exists()
        return roiset

    def test_export_object_classes(self):
        # Export object class maps from the RoiSet classified by test_classify_by and check the
        # written map exists and contains only values 0 and 1.
        record = self.test_classify_by().run_exports(
            output_path / 'object_class_maps',
            0,
            'obmap',
            RoiSetExportParams(object_classes=True)
        )
        # NOTE(review): `opa` and `acc` are not assigned in the visible code; the lines that
        # presumably read them from `record` are missing. Confirm against upstream.
        self.assertTrue(Path(opa).exists())
        self.assertTrue(all(np.unique(acc.data) == [0, 1]))

    def test_raw_patches_are_correct_shape(self):
        """The raw patch array has one patch per ROI and as many channels as the raw image."""
        roiset = self._make_roi_set()
        # NOTE(review): `patches` is not assigned in the visible source; the line producing it
        # appears to be missing from this view. Confirm against upstream.
        # Unpack the count into `n_patches`, not `np`: binding `np` anywhere in this function makes
        # it local for the whole body, shadowing numpy (UnboundLocalError on any earlier np.* use).
        n_patches, h, w, nc, nz = patches.shape
        self.assertEqual(n_patches, roiset.count)
        self.assertEqual(nc, roiset.acc_raw.chroma)

    def test_patch_masks_are_correct_shape(self):
        """Each ROI's patch mask has one channel, one z-slice, and the ROI's (h, w)."""
        patch_masks = self._make_roi_set().get_patch_masks()
        for row in patch_masks.itertuples():
            mask_h, mask_w, mask_c, mask_z = row.patch_mask.shape
            self.assertEqual(mask_c, 1)
            self.assertEqual(mask_z, 1)
            self.assertEqual(mask_h, row.h)
            self.assertEqual(mask_w, row.w)
class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):

    def setUp(self) -> None:
        # Load the shared fixture, then build the multichannel RoiSet used by this class's tests
        # (they read it as self.roiset).
        super().setUp()
        # NOTE(review): the statement these arguments belong to (presumably assigning
        # self.roiset) is not visible; its opening line is missing. Confirm against upstream.
            params=RoiSetMetaParams(
                expand_box_by=(128, 2),
                mask_type='boxes',
                filters={'area': {'min': 1e3, 'max': 1e4}},
            )
        )

    def test_multichannel_to_mono_2d_patches(self):
        # Export expanded 2D patches padded to 256 from the multichannel RoiSet and reload the first.
        where = output_path / 'multichannel' / 'mono_2d_patches'
        # NOTE(review): this call's remaining arguments (e.g. the output location) and its closing
        # parenthesis are not visible; lines appear to be missing. Confirm against upstream.
        df_res = self.roiset.export_patches(
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
    def test_multichannel_to_color_2d_patches(self):
        # Patches restricted to the channels in `chs` should have exactly len(chs) channels, both
        # in memory and after export.
        where = output_path / 'multichannel' / 'color_2d_patches'
        self.assertGreater(self.roiset.acc_raw.chroma, 1)
        # NOTE(review): `chs` is not assigned in the visible code; its definition appears to be
        # missing. Confirm against upstream.
        patches_acc = self.roiset.get_patches_acc(channels=chs)
        self.assertEqual(patches_acc.chroma, len(chs))

        df_res = self.roiset.export_patches(
            where,
            channels=chs,
            draw_bounding_box=True,
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
        self.assertEqual(result.chroma, len(chs))

    def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
        # Export expanded, padded 2D patches using channel 3 as white_channel, with a bounding box
        # drawn on channel 1.
        where = output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox'
        # NOTE(review): this call's remaining arguments and closing parenthesis, plus any checks on
        # `result`, are not visible; lines appear to be missing. Confirm against upstream.
        df_res = self.roiset.export_patches(
            white_channel=3,
            draw_bounding_box=True,
            bounding_box_channel=1,
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
    def test_multichannnel_to_rgb_2d_patches_bbox(self):
        # Weighted RGB overlay patches with a bounding box (no mask) should export as 3 channels.
        where = output_path / 'multichannel' / 'rgb_2d_patches_bbox'
        # NOTE(review): other export_patches calls in this class pass the output location as the
        # first positional argument; none is visible here, so a line (and possibly an
        # rgb_overlay_channels argument) may be missing. Confirm against upstream.
        df_res = self.roiset.export_patches(
            draw_mask=False,
            draw_bounding_box=True,
            bounding_box_channel=1,
            rgb_overlay_weights=(0.1, 1.0, 1.0),
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_mask(self):
        # Export expanded, padded RGB overlay patches (per the test name, with the mask drawn).
        where = output_path / 'multichannel' / 'rgb_2d_patches_mask'
        # NOTE(review): this call's remaining arguments and closing parenthesis, plus any checks on
        # `result`, are not visible; lines appear to be missing. Confirm against upstream.
        df_res = self.roiset.export_patches(
            rgb_overlay_weights=(0.1, 1.0, 1.0),
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
    def test_multichannnel_to_rgb_2d_patches_contour(self):
        # With only the red overlay slot mapped (channel 3), the exported patch's third channel
        # must be entirely zero.
        where = output_path / 'multichannel' / 'rgb_2d_patches_contour'
        # NOTE(review): this call's remaining arguments and closing parenthesis are not visible;
        # lines appear to be missing. Confirm against upstream.
        df_res = self.roiset.export_patches(
            rgb_overlay_channels=(3, None, None),
            rgb_overlay_weights=(0.1, 1.0, 1.0),
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
        self.assertEqual(result.get_mono(2).data.max(), 0)  # blue channel is black
    def test_multichannel_to_multichannel_tif_patches(self):
        # Exported expanded, padded 2D patches should be single z-slice files.
        where = output_path / 'multichannel' / 'multichannel_tif_patches'
        # NOTE(review): this call's remaining arguments and closing parenthesis are not visible;
        # lines appear to be missing. Confirm against upstream.
        df_res = self.roiset.export_patches(
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
        self.assertEqual(result.nz, 1)
    def test_multichannel_annotated_zstack(self):
        # The annotated z-stack exported from the multichannel RoiSet keeps the raw stack's channel
        # count and depth.
        # NOTE(review): the unindented author/"committed" lines in this method are web-page scrape
        # residue, not Python, and the export call's closing parenthesis is missing. Confirm
        # against upstream.
        where = output_path / 'multichannel' / 'annotated_zstack'
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
        file = self.roiset.export_annotated_zstack(
            'test_multichannel_annotated_zstack',
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / file)
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
        self.assertEqual(result.chroma, self.stack.chroma)
        self.assertEqual(result.nz, self.stack.nz)
    def test_export_single_channel_annotated_zstack(self):
        # A single-channel annotated z-stack keeps the raw image's (h, w) and depth, with 1 channel.
        where = output_path / 'annotated_zstack'
        # NOTE(review): this call's leading arguments and closing parenthesis are not visible;
        # lines appear to be missing. Confirm against upstream.
        file = self.roiset.export_annotated_zstack(
            expanded=True,
            pad_to=256,
        result = generate_file_accessor(where / file)
        self.assertEqual(result.hw, self.roiset.acc_raw.hw)
        self.assertEqual(result.nz, self.roiset.acc_raw.nz)
        self.assertEqual(result.chroma, 1)
    def test_run_exports(self):
        # Run every export type at once and check that all reported paths are relative to the
        # output directory and exist, including the paths listed in the exported dataframe CSV.
        p = RoiSetExportParams(**{
            'patches_3d': {},
            'annotated_patches_2d': {
                'draw_bounding_box': True,
                'rgb_overlay_channels': [3, None, None],
                'rgb_overlay_weights': [0.2, 1.0, 1.0],
                'pad_to': 512,
            },
            'patches_2d': {
                'draw_bounding_box': False,
                'draw_mask': False,
            },
            'patch_masks': {
                'pad_to': 256,
            },
            'annotated_zstacks': {},
            'object_classes': True,
            'dataframe': True,
        })

        where = output_path / 'run_exports'
        # NOTE(review): `res` (presumably the result of self.roiset.run_exports(where, ..., params=p))
        # is not assigned in the visible code. Also, an `else:` appears to be missing before the two
        # single-path checks below, which as shown run on the list `v` itself. Confirm against upstream.
        for k, v in res.items():
            if isinstance(v, list):
                for f in v:
                    self.assertFalse(Path(f).is_absolute())
                    self.assertTrue((where / f).exists())
                self.assertFalse(Path(v).is_absolute())
                self.assertTrue((where / v).exists())
        # test on paths in CSV
        test_df = pd.read_csv(where / res['dataframe'])
        for c in ['tight_patch_masks_path', 'patch_path']:
            self.assertTrue(c in test_df.columns)
            for f in test_df[c]:
                self.assertTrue((where / f).exists(), where / f)
    def test_run_export_expanded_2d_patch(self):
        """run_exports writes expanded 2D patches padded to 256 x 256."""
        patches_2d_params = {
            'draw_bounding_box': False,
            'draw_mask': False,
            'expanded': True,
            'pad_to': 256,
        }
        p = RoiSetExportParams(patches_2d=patches_2d_params)
        for attr in ('pad_to', 'expanded'):
            self.assertTrue(hasattr(p.patches_2d, attr))

        where = output_path / 'run_exports_expanded_2d_patch'
        res = self.roiset.run_exports(where, channel=-1, prefix='test', params=p)

        # every exported patch exists and has been padded to the requested size
        for fn in res['patches_2d']:
            patch_path = where / fn
            self.assertTrue(patch_path.exists())
            self.assertEqual(generate_file_accessor(patch_path).hw, (256, 256))

    def test_run_export_mono_2d_patch(self):
        """run_exports with ``rgb_overlay_channels=None`` writes single-channel (mono) 2D patches."""
        p = RoiSetExportParams(**{
            'patches_2d': {
                'draw_bounding_box': False,
                'draw_mask': False,
                'expanded': True,
                'pad_to': 256,
                'rgb_overlay_channels': None,
            },
        })
        self.assertTrue(hasattr(p.patches_2d, 'pad_to'))
        self.assertTrue(hasattr(p.patches_2d, 'expanded'))

        where = output_path / 'run_exports_mono_2d_patch'
        res = self.roiset.run_exports(
            where,
            channel=-1,
            prefix='test',
            params=p
        )

        # test that every exported patch exists and is single-channel
        for fn in res['patches_2d']:
            pa = where / fn
            self.assertTrue(pa.exists())
            pacc = generate_file_accessor(pa)
            self.assertEqual(pacc.chroma, 1)

from model_server.base.roiset import _get_label_ids

    def setUp(self) -> None:
        # Load the raw stack, its segmentation channel, and a 3D segmentation mask from the
        # test-data registry.
        zstack_info = roiset_test_data['multichannel_zstack']
        seg_channel = roiset_test_data['pipeline_params']['segmentation_channel']
        self.stack = generate_file_accessor(zstack_info['path'])
        self.stack_ch_pa = self.stack.get_mono(seg_channel)
        self.seg_mask_3d = generate_file_accessor(zstack_info['mask_path_3d'])

    @staticmethod
    def _label_is_2d(id_map, la):  # single label's zmask has same counts as its MIP
        mask_3d = (id_map == la)
        mask_mip = mask_3d.max(axis=-1)
        return mask_3d.sum() == mask_mip.sum()

    def test_id_map_connects_z(self):
        """With connect_3d=True, at least one label stacks along the last (presumably z) axis."""
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
        labels = np.unique(id_map.data)[1:]  # skip the smallest value, presumably background 0
        is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
        # is_2d was computed but never asserted; connecting across z must yield a 3D label
        self.assertFalse(is_2d)
    def test_id_map_disconnects_z(self):
        """With connect_3d=False, every label passes _label_is_2d (no stacking along the last axis)."""
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
        label_values = id_map.data
        # skip the smallest value, presumably background 0
        candidate_labels = np.unique(label_values)[1:]
        self.assertTrue(
            all(self._label_is_2d(label_values, la) for la in candidate_labels)
        )
    def test_create_roiset_from_3d_obj_ids(self):
        # Build a RoiSet from a z-disconnected 3D label map; expect one ROI per label and ROIs
        # spread over more than one z index.
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
        self.assertEqual(self.stack_ch_pa.shape, id_map.shape)

        # NOTE(review): the opening line of the call that builds `roiset` from the arguments below
        # is not visible; it appears to be missing. Confirm against upstream.
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(mask_type='contours')
        )
        self.assertEqual(roiset.count, id_map.data.max())
        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)

    def test_create_roiset_from_2d_obj_ids(self):
        # Build a RoiSet from a 2D (nz == 1) label map derived from the 3D mask; expect one ROI per
        # label and ROIs assigned to more than one z index. Returns the RoiSet for reuse.
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
        self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
        self.assertEqual(id_map.nz, 1)

        # NOTE(review): the opening line of the call that builds `roiset` from the arguments below
        # is not visible; it appears to be missing. Confirm against upstream.
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(mask_type='contours')
        )
        self.assertEqual(roiset.count, id_map.data.max())
        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
        return roiset

    def test_create_roiset_from_df_and_patch_masks(self):
        # Round trip: serialize a reference RoiSet, check the written dataframe and tight patch
        # masks, deserialize into a new RoiSet and compare, then re-serialize and compare the
        # patch-mask files pixel for pixel.
        ref_roiset = self.test_create_roiset_from_2d_obj_ids()
        # NOTE(review): `where_ser` and `patch_filenames` (appended to below) are not assigned in the
        # visible code; their definitions appear to be missing. Confirm against upstream.
        ref_roiset.serialize(where_ser, prefix='ref')
        where_df = where_ser / 'dataframe' / 'ref.csv'
        self.assertTrue(where_df.exists())
        df_test = pd.read_csv(where_df)

        where_patch_masks = where_ser / 'tight_patch_masks'
        for pmf in where_patch_masks.iterdir():
            # each serialized patch mask is a PNG whose size matches its ROI's (h, w) in the table
            self.assertTrue(pmf.suffix.upper() == '.PNG')
            la = int(re.search(r'la([\d]+)', str(pmf)).group(1))
            roi_q = df_test.loc[df_test.label == la, :]
            self.assertEqual(len(roi_q), 1)
            roi = roi_q.iloc[0]
            m_acc = generate_file_accessor(pmf)
            self.assertEqual((roi.h, roi.w), m_acc.hw)
            patch_filenames.append(pmf.name)
        # make another RoiSet from just the data table, raw images, and (tight) patch masks
        test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref')
        self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape,)
        self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all())
        self.assertTrue(np.all(test_roiset.get_df().label == ref_roiset.get_df().label))
        cols = ['label', 'y1', 'y0', 'x1', 'x0', 'zi']
        self.assertTrue((test_roiset.get_df()[cols] == ref_roiset.get_df()[cols]).all().all())
        # re-serialize and check that patch masks are the same
        where_dser = output_path / 'deserialize'
        test_roiset.serialize(where_dser, prefix='test')
        for fr in patch_filenames:
            # the re-serialized file name is the reference name with 'ref' replaced by 'test'
            pr = (where_ser / 'tight_patch_masks' / fr)
            self.assertTrue(pr.exists())
            pt = (where_dser / 'tight_patch_masks' / fr.replace('ref', 'test'))
            self.assertTrue(pt.exists())
            r_acc = generate_file_accessor(pr)
            t_acc = generate_file_accessor(pt)
            self.assertTrue(np.all(r_acc.data == t_acc.data))