"""Unit tests for ``model_server.base.roiset.RoiSet`` and its label-map helper ``_get_label_ids``."""
import unittest
from pathlib import Path

import numpy as np
import pandas as pd

from model_server.conf.testing import output_path, roiset_test_data
from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
from model_server.base.roiset import _get_label_ids, RoiSet, project_stack_from_focal_points
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import DummyInstanceSegmentationModel


class BaseTestRoiSetMonoProducts:
    """Mixin that loads the shared multichannel z-stack fixture and its segmentation mask.

    Not a ``TestCase`` by itself; concrete test classes combine it with ``unittest.TestCase``.
    """

    def setUp(self) -> None:
        # set up test raw data and segmentation from file
        self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
        self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel'])
        self.seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path'])


class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Tests of a RoiSet built on the single ``patches_channel`` extracted from the z-stack."""

    def _make_roi_set(self, mask_type='boxes', **kwargs):
        """Build a RoiSet on ``self.stack_ch_pa`` from the labels in ``self.seg_mask``.

        :param mask_type: forwarded to ``RoiSetMetaParams``
        :param kwargs: may contain ``filters``; defaults to an area filter with min 1e3 and max 1e4
        :return: the constructed RoiSet
        """
        id_map = _get_label_ids(self.seg_mask)
        return RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(
                mask_type=mask_type,
                filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
                expand_box_by=(64, 2),
            ),
        )

    def _classify_by(self):
        """Classify ROIs with the dummy model, check the results, and return the classified RoiSet."""
        roiset = self._make_roi_set()
        roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel())
        # compare as lists so an empty result or a length mismatch fails cleanly instead of passing vacuously
        self.assertEqual(roiset.get_df()['classify_by_dummy_class'].unique().tolist(), [1])
        self.assertEqual(np.unique(roiset.object_class_maps['dummy_class'].data).tolist(), [0, 1])
        return roiset

    def test_roi_mask_shape(self, **kwargs):
        roiset = self._make_roi_set(**kwargs)
        zmask = roiset.get_zmask()
        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

        # assert dimensionality of zmask
        self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz)
        self.assertEqual(zmask_acc.chroma, 1)
        write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)

        # mask values are not just all True or all False
        self.assertTrue(np.any(zmask))
        self.assertFalse(np.all(zmask))

        # assert non-trivial meta info in boxes
        self.assertGreater(roiset.count, 1)
        row = roiset.get_df().iloc[1]
        sh = row['mask'].shape
        ar = row['area']
        self.assertGreaterEqual(sh[0] * sh[1], ar)

    def test_roiset_from_non_zstacks(self, **kwargs):
        # keep only the first z-slice of the single-channel stack
        acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
        self.assertEqual(acc_zstack_slice.nz, 1)

        id_map = _get_label_ids(self.seg_mask)
        roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
        zmask = roiset.get_zmask()

        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

    def test_slices_are_valid(self):
        roiset = self._make_roi_set()
        for s in roiset.get_slices():
            ebb = roiset.acc_raw.data[s]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))

    def test_rel_slices_are_valid(self):
        roiset = self._make_roi_set()
        for roi in roiset:
            ebb = roiset.acc_raw.data[roi.slice]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
            # the relative slice indexes into the expanded bounding box, not the full stack
            rbb = ebb[roi.relative_slice]
            self.assertEqual(len(rbb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in rbb.shape]))

    def test_make_2d_patches(self):
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / '2d_patches',
            draw_bounding_box=True,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_make_3d_patches(self):
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / '3d_patches',
            make_3d=True,
        )
        self.assertGreaterEqual(len(files), 1)

    def test_export_annotated_zstack(self):
        roiset = self._make_roi_set()
        file = roiset.export_annotated_zstack(
            output_path / 'annotated_zstack',
        )
        result = generate_file_accessor(Path(file['location']) / file['filename'])
        self.assertEqual(result.shape, roiset.acc_raw.shape)

    def test_flatten_image(self):
        id_map = _get_label_ids(self.seg_mask)
        roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
        df = roiset.get_df()

        img = project_stack_from_focal_points(
            df['centroid-0'].to_numpy(),
            df['centroid-1'].to_numpy(),
            df['zi'].to_numpy(),
            self.stack,
            degree=4,
        )

        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])

        write_accessor_data_to_file(
            output_path / 'flattened.tif',
            InMemoryDataAccessor(img),
        )

    def test_make_binary_masks(self):
        roiset = self._make_roi_set()
        files = roiset.export_patch_masks(output_path / '2d_mask_patches')
        self.assertGreaterEqual(len(files), 1)

    def test_classify_by(self):
        self._classify_by()

    def test_export_object_classes(self):
        record = self._classify_by().run_exports(
            output_path / 'object_class_maps',
            0,
            'obmap',
            RoiSetExportParams(object_classes=True),
        )
        opa = record['object_classes_dummy_class']
        self.assertTrue(Path(opa).exists())
        acc = generate_file_accessor(opa)
        self.assertEqual(np.unique(acc.data).tolist(), [0, 1])

    def test_raw_patches_are_correct_shape(self):
        roiset = self._make_roi_set()
        patches = roiset.get_raw_patches()
        # unpack into n_patches rather than `np`, which would shadow numpy
        n_patches, _h, _w, nc, _nz = patches.shape
        self.assertEqual(n_patches, roiset.count)
        self.assertEqual(nc, roiset.acc_raw.chroma)

    def test_patch_masks_are_correct_shape(self):
        roiset = self._make_roi_set()
        patch_masks = roiset.get_patch_masks()
        n_patches, _h, _w, nc, _nz = patch_masks.shape
        self.assertEqual(n_patches, roiset.count)
        self.assertEqual(nc, 1)


class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Tests of patch and z-stack exports from a RoiSet built on the full multichannel stack."""

    def setUp(self) -> None:
        super().setUp()
        id_map = _get_label_ids(self.seg_mask)
        self.roiset = RoiSet(
            self.stack,
            id_map,
            params=RoiSetMetaParams(
                expand_box_by=(128, 2),
                mask_type='boxes',
                filters={'area': {'min': 1e3, 'max': 1e4}},
            ),
        )

    def test_multichannel_to_mono_2d_patches(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'mono_2d_patches',
            white_channel=3,
            draw_bounding_box=True,
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 1)

    def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox',
            white_channel=3,
            draw_bounding_box=True,
            bounding_box_channel=1,
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_bbox(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'rgb_2d_patches_bbox',
            white_channel=4,
            rgb_overlay_channels=(3, None, None),
            draw_mask=True,
            mask_channel=0,
            rgb_overlay_weights=(0.1, 1.0, 1.0),
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_contour(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'rgb_2d_patches_contour',
            rgb_overlay_channels=(3, None, None),
            draw_contour=True,
            contour_channel=1,
            rgb_overlay_weights=(0.1, 1.0, 1.0),
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 3)
        self.assertEqual(result.get_one_channel_data(2).data.max(), 0)  # blue channel is black

    def test_multichannel_to_multichannel_tif_patches(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'multichannel_tif_patches',
        )
        result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
        self.assertEqual(result.chroma, 5)

    def test_multichannel_annotated_zstack(self):
        file = self.roiset.export_annotated_zstack(
            output_path / 'multichannel' / 'annotated_zstack',
            'test_multichannel_annotated_zstack',
        )
        result = generate_file_accessor(Path(file['location']) / file['filename'])
        self.assertEqual(result.chroma, self.stack.chroma)
        self.assertEqual(result.nz, self.stack.nz)

    def test_export_single_channel_annotated_zstack(self):
        file = self.roiset.export_annotated_zstack(
            output_path / 'annotated_zstack',
            channel=3,
        )
        result = generate_file_accessor(Path(file['location']) / file['filename'])
        self.assertEqual(result.hw, self.roiset.acc_raw.hw)
        self.assertEqual(result.nz, self.roiset.acc_raw.nz)
        self.assertEqual(result.chroma, 1)


class TestRoiSetFromZmask(unittest.TestCase):
    """Tests of building label maps and RoiSets from a 3D segmentation mask."""

    def setUp(self) -> None:
        # set up test raw data and segmentation from file
        self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
        self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel'])
        self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d'])

    @staticmethod
    def _label_is_2d(id_map, la):
        """Return True if no pixel of label ``la`` is repeated along the last axis of ``id_map``.

        The check compares the label's voxel count with the pixel count of its max projection
        over the last axis; the two are equal only when no pixel is repeated along that axis.
        Assumes the last axis is z, consistent with the ``[..., zi]`` indexing used elsewhere in this file.
        """
        mask_3d = (id_map == la)
        mask_mip = mask_3d.max(axis=-1)
        return mask_3d.sum() == mask_mip.sum()

    def _make_roiset_from_2d_obj_ids(self):
        """Build a contour RoiSet from a 2D (nz == 1) label map of the 3D mask, check it, and return it."""
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
        self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
        self.assertEqual(id_map.nz, 1)

        roiset = RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(mask_type='contours'),
        )
        self.assertEqual(roiset.count, id_map.data.max())
        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
        return roiset

    def test_id_map_connects_z(self):
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
        labels = np.unique(id_map.data)[1:]  # skip background label 0
        is_2d = all(self._label_is_2d(id_map.data, la) for la in labels)
        self.assertFalse(is_2d)

    def test_id_map_disconnects_z(self):
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
        labels = np.unique(id_map.data)[1:]  # skip background label 0
        is_2d = all(self._label_is_2d(id_map.data, la) for la in labels)
        self.assertTrue(is_2d)

    def test_create_roiset_from_3d_obj_ids(self):
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
        self.assertEqual(self.stack_ch_pa.shape, id_map.shape)

        roiset = RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(mask_type='contours'),
        )
        self.assertEqual(roiset.count, id_map.data.max())
        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)

    def test_create_roiset_from_2d_obj_ids(self):
        self._make_roiset_from_2d_obj_ids()

    def test_create_roiset_from_df_and_patch_masks(self):
        ref_roiset = self._make_roiset_from_2d_obj_ids()
        ref_roiset.run_exports(
            output_path / 'roiset_from_3d',
            roiset_test_data['pipeline_params']['segmentation_channel'],
            'ref',
            params=RoiSetExportParams(patch_masks={'pad_to': 256}, dataframe=True),
        )
        where_df = output_path / 'roiset_from_3d' / 'dataframe' / 'ref.csv'
        self.assertTrue(where_df.exists())
        df_test = pd.read_csv(where_df)

        # per-ROI patch mask filename template, filled with each row's label and z-index
        fn = output_path / 'roiset_from_3d' / 'patch_masks' / 'ref-la{:04d}-zi{:04d}.png'
        patch_masks = {}

        def _label_obj(r):
            # the exported 'slice' column must match the slice rebuilt from the ebb_* and zi columns
            sl = np.s_[r.ebb_y0:r.ebb_y1, r.ebb_x0:r.ebb_x1, :, r.zi:r.zi + 1]
            self.assertEqual(str(sl), r.slice)
            patch_masks[r.label] = generate_file_accessor(str(fn).format(r.label, r.zi)).data

        df_test.apply(_label_obj, axis=1)

        # smoke test: only checks that construction does not raise
        # TODO(review): add assertions on the reconstructed RoiSet (e.g. compare against ref_roiset)
        RoiSet.from_df_and_patch_masks(self.stack, df_test, patch_masks)