diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 9c4abafc574e8210f0419ff57d7851930997b313..d92279b4bdabc665ee4da05b31369847245cf27a 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -802,13 +802,10 @@ class RoiSet(object): def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame: patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy() + ext = 'tif' if is_df_3d(patches_df) else 'png' def _export_patch_mask(roi): patch = InMemoryDataAccessor.from_mono(roi.patch_mask) - if patch.nz == 1: - ext = 'png' - else: - ext = 'tif' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' write_accessor_data_to_file(where / fname, patch) return fname @@ -1199,16 +1196,17 @@ class RoiSet(object): :param prefix: starting prefix of patch mask filenames :return: RoiSet object """ - df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] + df = pd.read_csv(where / 'dataframe' / (prefix + '.csv')) pa_masks = where / 'tight_patch_masks' + is_3d = is_df_3d(df) + ext = 'tif' if is_3d else 'png' if pa_masks.exists(): # import segmentation masks def _read_binary_mask(r): - ext = 'png' fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' try: ma_acc = generate_file_accessor(pa_masks / fname) - if is_df_3d(df): + if is_3d: mask_data = ma_acc.data_yxz / ma_acc.dtype_max else: mask_data = ma_acc.data_yx / ma_acc.dtype_max diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index a23c4af63b83c7c53ff457b8a8b9e897864c53b3..f0a099f2005a30e2f5a15c4d8eb5b440db67ff53 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -251,8 +251,6 @@ class TestRoiSet3dProducts(unittest.TestCase): where = output_path / 'run_exports_mono_3d' - # TODO: test serialization/deserialization of 3d patches - def setUp(self) -> None: # set up test raw data and segmentation from file self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) @@ -317,6 +315,20 @@ class TestRoiSet3dProducts(unittest.TestCase): acc = generate_file_accessor(self.where / res['labels_overlay']) self.assertGreater(acc.nz, 1) + def test_serialize_and_deserialize_3d_patche(self): + ref_roiset = self.test_create_roiset_from_3d_obj_ids() + ref_roiset.serialize(self.where / 'serialize', prefix='ref') + + # make another RoiSet from just the data table, raw images, and (tight) patch masks + test_roiset = RoiSet.deserialize(self.stack_ch_pa, self.where / 'serialize', prefix='ref') + self.assertEqual(ref_roiset.get_zmask().shape, test_roiset.get_zmask().shape, ) + self.assertTrue((ref_roiset.get_zmask() == test_roiset.get_zmask()).all()) + self.assertTrue( + np.all( + test_roiset.get_df().label.unique() == ref_roiset.get_df().label.unique() + ) + ) + class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def setUp(self) -> None: