diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 86b6afe42c62b8865f1b7826f89501e0c8ecb754..ac25c94e4e3981cdde11e1bf4fa6baf4a9c38352 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -49,6 +49,16 @@ class GenericImageDataAccessor(ABC): nda = self.data.take(indices=carr, axis=self._ga('C')) return self._derived_accessor(nda) + + def get_zi(self, zi: int): + return self._derived_accessor( + self.data.take( + indices=[zi], + axis=self._ga('Z') + ) + ) + + def get_mono(self, channel: int, mip: bool = False): return self.get_channels([channel], mip=mip) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 68ee831a58a7088d74e8959e8e47e5734780ac6e..4b6f2d6f68ca6d717bcfb44fac5913bd6f3dd384 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -673,6 +673,28 @@ class RoiSet(object): record['tight_patch_masks'] = list(se_pa) return record + def serialize_coco(self, where: Path, prefix='') -> dict: + """ + Export the RoiSet according to the COCO seg standard + :param where: path of directory in which to write files + :param prefix: (optional) prefix + :return: nested dict of Path objects describing the locations of export products + """ + + # df_coco = self.get_df() + + df_p = self.get_patches(expanded=False) + + def _export_zi(df_zi): + zi = df_zi.zi.iat[0] + acc_zi = self.acc_raw.get_zi(zi) + return None + + + df_p.groupby('zi').apply(_export_zi) + + return {} + @staticmethod def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 014eda2c283ce9297f3084513abad66497303cbd..56a07860b735509419b36e60d53710615fda5a6f 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -58,6 +58,18 @@ class TestCziImageFileAccess(unittest.TestCase): sc = cf.get_mono(c, mip=True) self.assertEqual(sc.shape, (h, w, 1, 1)) + def test_get_zi(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + zi = 5 + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) + sz = cf.get_zi(zi) + self.assertEqual(sz.shape_dict['Z'], 1) + + self.assertTrue(np.all(sz.data[:, :, :, 0] == cf.data[:, :, :, zi])) + def test_write_single_channel_tif(self): ch = 4 cf = CziImageFileAccessor(czifile['path']) diff --git a/tests/test_roiset.py b/tests/test_roiset.py index efc0779f067ebffba0acdb5c0f8ff1852eccef2b..d3ecb18ec99f2d04ea08485b22dbadedc9b0beeb 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -525,6 +525,7 @@ class TestRoiSetSerialization(unittest.TestCase): self.stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) self.stack_ch_pa = self.stack.get_mono(roiset_test_data['pipeline_params']['segmentation_channel']) self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d']) + self.seg_mask_2d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path']) @staticmethod def _label_is_2d(id_map, la): # single label's zmask has same counts as its MIP @@ -611,3 +612,11 @@ class TestRoiSetSerialization(unittest.TestCase): t_acc = generate_file_accessor(pt) self.assertTrue(np.all(r_acc.data == t_acc.data)) + def test_serialize_coco(self): + roiset = RoiSet.from_segmentation( + self.stack_ch_pa, + self.seg_mask_3d, + params=RoiSetMetaParams(mask_type='contours') + ) + roiset.serialize_coco(output_path / 'serialize_coco') + self.assertEqual(1, 0)