import os
import re
import unittest

import numpy as np
from pathlib import Path
import pandas as pd

from model_server.conf.testing import output_path, roiset_test_data

from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
from model_server.base.roiset import _get_label_ids, RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import DummyInstanceSegmentationModel


def _label_from_filename(path) -> int:
    """Return the integer label encoded as ``la<digits>`` in an exported file's path.

    The first ``la<digits>`` occurrence in ``str(path)`` is used.

    Raises:
        ValueError: if no ``la<digits>`` token is present, so a naming change in the
            exporter produces a clear failure rather than an AttributeError on None.
    """
    match = re.search(r'la(\d+)', str(path))
    if match is None:
        raise ValueError(f'No label token "la<digits>" found in {path}')
    return int(match.group(1))


class BaseTestRoiSetMonoProducts(object):
    """Fixture mixin: loads the multichannel test z-stack, one channel of it, and its segmentation mask."""

    def setUp(self) -> None:
        # set up test raw data and segmentation from file
        self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
        self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['patches_channel'])
        self.seg_mask = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path'])


class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Tests of RoiSet products built from a single channel of the test z-stack."""

    def _make_roi_set(self, mask_type='boxes', **kwargs):
        """Build a RoiSet on the single patches channel.

        Args:
            mask_type: passed through to RoiSetMetaParams.
            **kwargs: only ``filters`` is read; defaults to ``{'area': {'min': 1e3, 'max': 1e4}}``.
        """
        id_map = _get_label_ids(self.seg_mask)
        roiset = RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(
                mask_type=mask_type,
                filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
                expand_box_by=(64, 2)
            )
        )
        return roiset

    def test_roi_mask_shape(self, **kwargs):
        roiset = self._make_roi_set(**kwargs)

        # all masks' bounding boxes are at least as big as ROI area
        for roi in roiset.get_df().itertuples():
            self.assertEqual(roi.binary_mask.dtype, 'bool')
            sh = roi.binary_mask.shape
            self.assertEqual(sh, (roi.h, roi.w))
            self.assertGreaterEqual(sh[0] * sh[1], roi.area)

    def test_roi_zmask(self, **kwargs):
        roiset = self._make_roi_set(**kwargs)
        zmask = roiset.get_zmask()
        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

        # assert dimensionality of zmask
        self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz)
        self.assertEqual(zmask_acc.chroma, 1)
        write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)

        # mask values are not just all True or all False
        self.assertTrue(np.any(zmask))
        self.assertFalse(np.all(zmask))

    def test_roiset_from_non_zstacks(self, **kwargs):
        # take a single z-slice of the patches channel so the accessor has nz == 1
        acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
        self.assertEqual(acc_zstack_slice.nz, 1)

        id_map = _get_label_ids(self.seg_mask)

        roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
        zmask = roiset.get_zmask()

        zmask_acc = InMemoryDataAccessor(zmask)
        self.assertTrue(zmask_acc.is_mask())

    def test_slices_are_valid(self):
        roiset = self._make_roi_set()
        for s in roiset.get_slices():
            ebb = roiset.acc_raw.data[s]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))

    def test_dataframe_and_mask_array_in_iterator(self):
        roiset = self._make_roi_set()
        for roi in roiset:
            ma = roi.binary_mask
            self.assertEqual(ma.dtype, 'bool')
            self.assertEqual(ma.shape, (roi.h, roi.w))

    def test_rel_slices_are_valid(self):
        roiset = self._make_roi_set()
        for roi in roiset:
            ebb = roiset.acc_raw.data[roi.expanded_slice]
            self.assertEqual(len(ebb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
            rbb = ebb[roi.relative_slice]
            self.assertEqual(len(rbb.shape), 4)
            self.assertTrue(np.all([si >= 1 for si in rbb.shape]))

    def test_make_expanded_2d_patches(self):
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / 'expanded_2d_patches',
            draw_bounding_box=True,
            expanded=True,
            pad_to=256,
        )
        # guard against the loop below passing vacuously
        self.assertGreaterEqual(len(files), 1)
        df = roiset.get_df()
        for f in files:
            acc = generate_file_accessor(f)
            la = _label_from_filename(f)
            roi_q = df.loc[df.label == la, :]
            self.assertEqual(len(roi_q), 1)
            self.assertEqual((256, 256), acc.hw)

    def test_make_tight_2d_patches(self):
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / 'tight_2d_patches',
            draw_bounding_box=True,
            expanded=False
        )
        # guard against the loop below passing vacuously
        self.assertGreaterEqual(len(files), 1)
        df = roiset.get_df()
        for f in files:  # all exported files are same shape as bounding boxes in RoiSet's datatable
            acc = generate_file_accessor(f)
            la = _label_from_filename(f)
            roi_q = df.loc[df.label == la, :]
            self.assertEqual(len(roi_q), 1)
            roi = roi_q.iloc[0]
            self.assertEqual((roi.h, roi.w), acc.hw)

    def test_make_expanded_3d_patches(self):
        roiset = self._make_roi_set()
        files = roiset.export_patches(
            output_path / '3d_patches',
            make_3d=True,
            expanded=True
        )
        self.assertGreaterEqual(len(files), 1)
        for f in files:
            acc = generate_file_accessor(f)
            self.assertGreater(acc.nz, 1)

    def test_export_annotated_zstack(self):
        roiset = self._make_roi_set()
        file = roiset.export_annotated_zstack(
            output_path / 'annotated_zstack',
        )
        result = generate_file_accessor(file)
        self.assertEqual(result.shape, roiset.acc_raw.shape)

    def test_flatten_image(self):
        id_map = _get_label_ids(self.seg_mask)
        roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
        df = roiset.get_df()

        from model_server.base.roiset import project_stack_from_focal_points

        img = project_stack_from_focal_points(
            df['centroid-0'].to_numpy(),
            df['centroid-1'].to_numpy(),
            df['zi'].to_numpy(),
            self.stack,
            degree=4,
        )

        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])

        write_accessor_data_to_file(
            output_path / 'flattened.tif',
            InMemoryDataAccessor(img)
        )

    def test_make_binary_masks(self):
        roiset = self._make_roi_set()
        files = roiset.export_patch_masks(output_path / '2d_mask_patches', )
        self.assertGreaterEqual(len(files), 1)

    def test_classify_by(self):
        """Classify with the dummy model; also returns the RoiSet for reuse by other tests."""
        roiset = self._make_roi_set()
        roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel())
        # array_equal (unlike all(a == b)) fails on empty results and on length mismatches
        self.assertTrue(np.array_equal(roiset.get_df()['classify_by_dummy_class'].unique(), [1]))
        self.assertTrue(np.array_equal(np.unique(roiset.object_class_maps['dummy_class'].data), [0, 1]))
        return roiset

    def test_export_object_classes(self):
        record = self.test_classify_by().run_exports(
            output_path / 'object_class_maps',
            0,
            'obmap',
            RoiSetExportParams(object_classes=True)
        )
        opa = record['object_classes_dummy_class']
        self.assertTrue(Path(opa).exists())
        acc = generate_file_accessor(opa)
        self.assertTrue(np.array_equal(np.unique(acc.data), [0, 1]))

    def test_raw_patches_are_correct_shape(self):
        roiset = self._make_roi_set()
        patches = roiset.get_patches_acc()
        # do not bind a local named ``np``; it would shadow the numpy module
        n_patches, _, _, nc, nz = patches.shape
        self.assertEqual(n_patches, roiset.count)
        self.assertEqual(nc, roiset.acc_raw.chroma)
        self.assertEqual(nz, 1)

    def test_patch_masks_are_correct_shape(self):
        roiset = self._make_roi_set()
        df_patch_masks = roiset.get_patch_masks()
        for roi in df_patch_masks.itertuples():
            h, w, nc, nz = roi.patch_mask.shape
            self.assertEqual(nc, 1)
            self.assertEqual(nz, 1)
            self.assertEqual(h, roi.h)
            self.assertEqual(w, roi.w)


class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Tests of RoiSet exports built from the full multichannel z-stack."""

    def setUp(self) -> None:
        super().setUp()
        id_map = _get_label_ids(self.seg_mask)
        self.roiset = RoiSet(
            self.stack,
            id_map,
            params=RoiSetMetaParams(
                expand_box_by=(128, 2),
                mask_type='boxes',
                filters={'area': {'min': 1e3, 'max': 1e4}},
            )
        )

    def test_multichannel_to_mono_2d_patches(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'mono_2d_patches',
            white_channel=3,
            draw_bounding_box=True,
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(files[0])
        self.assertEqual(result.chroma, 1)

    def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox',
            white_channel=3,
            draw_bounding_box=True,
            bounding_box_channel=1,
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(files[0])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_bbox(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'rgb_2d_patches_bbox',
            white_channel=4,
            rgb_overlay_channels=(3, None, None),
            draw_mask=True,
            mask_channel=0,
            rgb_overlay_weights=(0.1, 1.0, 1.0),
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(files[0])
        self.assertEqual(result.chroma, 3)

    def test_multichannnel_to_rgb_2d_patches_contour(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'rgb_2d_patches_contour',
            rgb_overlay_channels=(3, None, None),
            draw_contour=True,
            contour_channel=1,
            rgb_overlay_weights=(0.1, 1.0, 1.0),
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(files[0])
        self.assertEqual(result.chroma, 3)
        self.assertEqual(result.get_one_channel_data(2).data.max(), 0)  # blue channel is black

    def test_multichannel_to_multichannel_tif_patches(self):
        files = self.roiset.export_patches(
            output_path / 'multichannel' / 'multichannel_tif_patches',
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(files[0])
        self.assertEqual(result.chroma, 5)
        self.assertEqual(result.nz, 1)

    def test_multichannel_annotated_zstack(self):
        file = self.roiset.export_annotated_zstack(
            output_path / 'multichannel' / 'annotated_zstack',
            'test_multichannel_annotated_zstack',
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(file)
        self.assertEqual(result.chroma, self.stack.chroma)
        self.assertEqual(result.nz, self.stack.nz)

    def test_export_single_channel_annotated_zstack(self):
        file = self.roiset.export_annotated_zstack(
            output_path / 'annotated_zstack',
            channel=3,
            expanded=True,
            pad_to=256,
        )
        result = generate_file_accessor(file)
        self.assertEqual(result.hw, self.roiset.acc_raw.hw)
        self.assertEqual(result.nz, self.roiset.acc_raw.nz)
        self.assertEqual(result.chroma, 1)


class TestRoiSetFromZmask(unittest.TestCase):
    """Tests of building label maps and RoiSets from a 3D segmentation mask."""

    def setUp(self) -> None:
        # set up test raw data and segmentation from file
        self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
        self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel'])
        self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d'])

    @staticmethod
    def _label_is_2d(id_map, la):
        """Return True if label ``la`` occupies at most one plane along the last axis.

        A single label's zmask has the same pixel count as its max projection over the
        last axis (assumed to be z) only when it never overlaps itself across planes.
        """
        mask_3d = (id_map == la)
        mask_mip = mask_3d.max(axis=-1)
        return mask_3d.sum() == mask_mip.sum()

    def test_id_map_connects_z(self):
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True)
        labels = np.unique(id_map.data)[1:]  # [1:] skips the lowest value (background)
        is_2d = all(self._label_is_2d(id_map.data, la) for la in labels)
        self.assertFalse(is_2d)

    def test_id_map_disconnects_z(self):
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
        labels = np.unique(id_map.data)[1:]  # [1:] skips the lowest value (background)
        is_2d = all(self._label_is_2d(id_map.data, la) for la in labels)
        self.assertTrue(is_2d)

    def test_create_roiset_from_3d_obj_ids(self):
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
        self.assertEqual(self.stack_ch_pa.shape, id_map.shape)

        roiset = RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(mask_type='contours')
        )
        self.assertEqual(roiset.count, id_map.data.max())
        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)

    def test_create_roiset_from_2d_obj_ids(self):
        """Build a RoiSet from 2D labels; also returns it for reuse by other tests."""
        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
        self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
        self.assertEqual(id_map.nz, 1)

        roiset = RoiSet(
            self.stack_ch_pa,
            id_map,
            params=RoiSetMetaParams(mask_type='contours')
        )
        self.assertEqual(roiset.count, id_map.data.max())
        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
        return roiset

    def test_create_roiset_from_df_and_patch_masks(self):
        ref_roiset = self.test_create_roiset_from_2d_obj_ids()
        ref_roiset.run_exports(
            output_path / 'roiset_from_3d',
            roiset_test_data['pipeline_params']['segmentation_channel'],
            'ref',
            params=RoiSetExportParams()
        )
        where_df = output_path / 'roiset_from_3d' / 'dataframe' / 'ref.csv'
        self.assertTrue(where_df.exists())
        df_test = pd.read_csv(where_df)

        # check that patches are correct size
        where_patch_masks = output_path / 'roiset_from_3d' / 'tight_patch_masks'
        patch_mask_files = list(where_patch_masks.iterdir())
        # guard against the loop below passing vacuously
        self.assertGreaterEqual(len(patch_mask_files), 1)
        for pmf in patch_mask_files:
            self.assertEqual(pmf.suffix.upper(), '.PNG')
            la = _label_from_filename(pmf)
            roi_q = df_test.loc[df_test.label == la, :]
            self.assertEqual(len(roi_q), 1)
            roi = roi_q.iloc[0]
            m_acc = generate_file_accessor(pmf)
            self.assertEqual((roi.h, roi.w), m_acc.hw)

        # TODO: unfinished draft of a round-trip check via RoiSet.from_df_and_patch_masks, kept for reference
        # df_test = pd.read_csv(where_df)
        #
        # # zmask = np.zeros((*self.stack.hw, 1, self.stack.nz), dtype=bool)
        # print('hi')
        #
        # fn = output_path / 'roiset_from_3d' / 'patch_masks' / 'ref-la{:04d}-zi{:04d}.png'
        # patch_masks = {}
        #
        # def _label_obj(r):
        #     sl = np.s_[r.ebb_y0:r.ebb_y1, r.ebb_x0:r.ebb_x1, :, r.zi:r.zi + 1]
        #     self.assertEqual(str(sl), r.slice)
        #     patch_masks[r.label] = generate_file_accessor(str(fn).format(r.label, r.zi)).data
        #     # zmask[sl] = True
        #
        # df_test.apply(lambda x: _label_obj(x), axis=1)
        #
        # roiset_test = RoiSet.from_df_and_patch_masks(self.stack, df_test, patch_masks)
        # print('')