diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 20efb6569766e1bea0a3321660b752422aa4bba0..4e23d3f538c4f9b359636b8558bc2119e2022aec 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -65,8 +65,18 @@ class RoiSetExportParams(BaseModel): -def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor: - return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')) +def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False) -> InMemoryDataAccessor: + if allow_3d: + nda_la = label( + acc_seg_mask.data[:, :, 0, :] + ).astype('uint16') + return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) + else: + return InMemoryDataAccessor( + label( + acc_seg_mask.data[:, :, 0, :].max(axis=-1) + ).astype('uint16') + ) def _focus_metrics(): @@ -197,9 +207,31 @@ class RoiSet(object): params: RoiSetMetaParams = RoiSetMetaParams() ): - assert acc_obj_ids.nz == 1 + assert acc_obj_ids.nz == 1, 'Can only use this method with a 2D object identities map' + return RoiSet(acc_raw, acc_obj_ids, params) + + @staticmethod + def from_3d_obj_ids( + acc_raw: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, + params: RoiSetMetaParams = RoiSetMetaParams(), + ): + assert acc_obj_ids.nz > 1, 'Can only use this method with a 3D object identities map' return RoiSet(acc_raw, acc_obj_ids, params) + @staticmethod + def from_df_and_patch_masks( + acc_raw: GenericImageDataAccessor, + df: pd.DataFrame, + patch_masks: dict, # dict of ndarray, where key is integer label + ): + assert len(df) == len(patch_masks) + se_patch_masks = pd.Series(patch_masks) + # df_merged = pd.merge(df, se_patch_masks) + assert all(patch_masks.keys()) + assert df.apply(lambda x: x, axis=1) + assert False + @staticmethod def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: query_str = 'label > 0' # always true diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 3d07931ee087dba7eb56f416f232b16b89056161..97a13aaf75192c76a75ace8c8bcccf3678370a5d 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -67,6 +67,7 @@ roiset_test_data = { 'c': 5, 'z': 7, 'mask_path': root / 'zmask-test-stack-mask.tif', + 'mask_path_3d': root / 'zmask-test-stack-mask-3d.tif', }, 'pipeline_params': { 'segmentation_channel': 0, diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 02602506041d2999b480a827a8353071c77fa4bf..3b3d1fc402c6ceda6bd447282732d41725b42676 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -41,8 +41,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue(zmask_acc.is_mask()) # assert dimensionality of zmask - self.assertGreater(zmask_acc.shape_dict['Z'], 1) - self.assertEqual(zmask_acc.shape_dict['C'], 1) + self.assertEqual(zmask_acc.nz, roiset.acc_raw.nz) + self.assertEqual(zmask_acc.chroma, 1) write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc) # mask values are not just all True or all False @@ -253,3 +253,41 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.nz, self.roiset.acc_raw.nz) self.assertEqual(result.chroma, 1) +class TestRoiSetFromZmask(unittest.TestCase): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel']) + self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d']) + + id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True) + self.assertGreater(id_map.nz, 1) + + roiset = RoiSet.from_3d_obj_ids( + self.stack_ch_pa, + id_map, + params=RoiSetMetaParams( + mask_type='contours', + filters={'area': {'min': 1e3, 'max': 1e4}}, + ) + ) + self.roiset = roiset + self.zmask = InMemoryDataAccessor(roiset.get_zmask()) + + def test_id_map_connects_z(self): + id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True) + labels = np.unique(id_map.data)[1:] + def _label_is_2d(la): # single label's zmask has same counts as its MIP + mask_3d = (id_map.data == la) + mask_mip = mask_3d.max(axis=-1) + return mask_3d.sum() == mask_mip.sum() + is_2d = all([_label_is_2d(la) for la in labels]) + print([_label_is_2d(la) for la in labels]) + self.assertFalse(is_2d) + + def test_3d_zmask(self): + write_accessor_data_to_file(output_path / 'roiset_from_3d' / 'raw.tif', self.roiset.acc_raw) + write_accessor_data_to_file(output_path / 'roiset_from_3d' / 'ob_ids.tif', self.roiset.acc_obj_ids) + write_accessor_data_to_file(output_path / 'roiset_from_3d' / 'zmask.tif', self.zmask) +