From 29bbcf6f8b16aacd3dd6a0f413afe889e7411de5 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 23 Sep 2024 11:24:15 +0200 Subject: [PATCH] Implemented and tested serialize/deserialize methods on object detection RoiSets --- model_server/base/roiset.py | 77 ++++++++++++++++++++++--------------- tests/base/test_roiset.py | 18 +++++++-- 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 3e99eaef..4eeda490 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -427,10 +427,11 @@ class RoiSet(object): bbox_df['x0'] = bbox_df['x'] bbox_df['y1'] = bbox_df['y0'] + bbox_df['h'] bbox_df['x1'] = bbox_df['x0'] + bbox_df['w'] + bbox_df['label'] = bbox_df.index df = df_insert_slices( - bbox_df[['y0', 'x0', 'y1', 'x1', 'zi']], + bbox_df[['y0', 'x0', 'y1', 'x1', 'zi', 'label']], acc_raw.shape_dict, params.expand_box_by, ) @@ -924,7 +925,7 @@ class RoiSet(object): return record - def serialize(self, where: Path, prefix='') -> dict: + def serialize(self, where: Path, prefix='roiset') -> dict: """ Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks :param where: path of directory in which to write files @@ -932,18 +933,20 @@ class RoiSet(object): :return: nested dict of Path objects describing the locations of export products """ record = {} - df_exp = self.export_patch_masks( - where / 'tight_patch_masks', - prefix=prefix, - pad_to=None, - expanded=False - ) - # record patch masks paths to dataframe, then save static columns to CSV - se_pa = df_exp.patch_mask_path.apply( - lambda x: str(Path('tight_patch_masks') / x) - ).rename('tight_patch_masks_path') + if not self._df.binary_mask.apply(lambda x: np.all(x)).all(): # binary masks aren't just all True + df_exp = self.export_patch_masks( + where / 'tight_patch_masks', + prefix=prefix, + pad_to=None, + expanded=False + ) + # record patch masks paths to dataframe, then save static columns to CSV + se_pa = df_exp.patch_mask_path.apply( + lambda x: str(Path('tight_patch_masks') / x) + ).rename('tight_patch_masks_path') + self._df = self._df.join(se_pa) + record['tight_patch_masks'] = list(se_pa) - self._df = self._df.join(se_pa) csv_path = where / 'dataframe' / (prefix + '.csv') csv_path.parent.mkdir(parents=True, exist_ok=True) self._df.drop( @@ -952,9 +955,10 @@ class RoiSet(object): ).to_csv(csv_path, index=False) record['dataframe'] = str(Path('dataframe') / csv_path.name) - record['tight_patch_masks'] = list(se_pa) + return record + def get_polygons(self, poly_threshold=0, dilation_radius=1) -> pd.DataFrame: self.coordinates_ = """ Fit polygons to all object boundaries in the RoiSet @@ -994,9 +998,8 @@ class RoiSet(object): def acc_obj_ids(self): return make_object_ids_from_df(self._df, self.acc_raw.shape_dict) - # TODO: make this work with obj det dataset @staticmethod - def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='') -> Self: + def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: """ Create an RoiSet object from saved files and an image accessor :param acc_raw: accessor to image that contains ROIs @@ -1006,20 +1009,34 @@ class RoiSet(object): :return: RoiSet object """ df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] - - def _read_binary_mask(r): - ext = 'png' - fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' - try: - ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) - assert ma_acc.chroma == 1 and ma_acc.nz == 1 - mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max - return mask_data - except Exception as e: - raise DeserializeRoiSet(e) - df['binary_mask'] = df.apply(_read_binary_mask, axis=1) - id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) - return RoiSet.from_object_ids(acc_raw, id_mask) + pa_masks = where / 'tight_patch_masks' + + if pa_masks.exists(): # import segmentation masks + def _read_binary_mask(r): + ext = 'png' + fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' + try: + ma_acc = generate_file_accessor(pa_masks / fname) + assert ma_acc.chroma == 1 and ma_acc.nz == 1 + mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max + return mask_data + except Exception as e: + raise DeserializeRoiSet(e) + + df['binary_mask'] = df.apply(_read_binary_mask, axis=1) + id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) + return RoiSet.from_object_ids(acc_raw, id_mask) + + else: # assume bounding boxes only + df['y'] = df['y0'] + df['x'] = df['x0'] + df['h'] = df['y1'] - df['y0'] + df['w'] = df['x1'] - df['x0'] + return RoiSet.from_bounding_boxes( + acc_raw, + df[['y', 'x', 'h', 'w']].to_dict(orient='records'), + list(df['zi']) + ) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 9de818ee..93e96df7 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -637,7 +637,6 @@ class TestRoiSetObjectDetection(unittest.TestCase): def test_create_roiset_from_bounding_boxes(self): from skimage.measure import label, regionprops, regionprops_table - mask = self.seg_mask_3d labels = label(mask.data_xyz, connectivity=3) table = pd.DataFrame( @@ -651,12 +650,12 @@ class TestRoiSetObjectDetection(unittest.TestCase): table['h'] = table['y1'] - table['y'] bboxes = table[['y', 'x', 'h', 'w']].to_dict(orient='records') - roiset_bbox = RoiSet.from_bounding_boxes(self.stack_ch_pa, bboxes) - + self.assertTrue('label' in roiset_bbox.get_df().columns) patches_bbox = roiset_bbox.get_patches_acc() self.assertEqual(len(table), patches_bbox.count) + # roiset w/ seg for comparison roiset_seg = RoiSet.from_binary_mask(self.stack_ch_pa, mask, allow_3d=True) patches_seg = roiset_seg.get_patches_acc() @@ -666,6 +665,19 @@ class TestRoiSetObjectDetection(unittest.TestCase): for i in range(0, roiset_seg.count): self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape) + # test that serialization does not write patch masks + roiset_ser_path = output_path / 'roiset_from_bbox' + dd = roiset_bbox.serialize(roiset_ser_path) + self.assertTrue('tight_patch_masks' not in dd.keys()) + self.assertFalse((roiset_ser_path / 'tight_patch_masks').exists()) + + # test that deserialized RoiSet matches the original + roiset_des = RoiSet.deserialize(self.stack_ch_pa, roiset_ser_path) + self.assertEqual(roiset_des.count, roiset_bbox.count) + for i in range(0, roiset_des.count): + self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape) + self.assertTrue((roiset_bbox.get_zmask() == roiset_des.get_zmask()).all()) + class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): -- GitLab