From 29bbcf6f8b16aacd3dd6a0f413afe889e7411de5 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 23 Sep 2024 11:24:15 +0200
Subject: [PATCH] Implemented and tested serialize/deserialize methods on
 object detection RoiSets

---
 model_server/base/roiset.py | 77 ++++++++++++++++++++++---------------
 tests/base/test_roiset.py   | 18 +++++++--
 2 files changed, 62 insertions(+), 33 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 3e99eaef..4eeda490 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -427,10 +427,11 @@ class RoiSet(object):
         bbox_df['x0'] = bbox_df['x']
         bbox_df['y1'] = bbox_df['y0'] + bbox_df['h']
         bbox_df['x1'] = bbox_df['x0'] + bbox_df['w']
+        bbox_df['label'] = bbox_df.index
 
 
         df = df_insert_slices(
-            bbox_df[['y0', 'x0', 'y1', 'x1', 'zi']],
+            bbox_df[['y0', 'x0', 'y1', 'x1', 'zi', 'label']],
             acc_raw.shape_dict,
             params.expand_box_by,
         )
@@ -924,7 +925,7 @@ class RoiSet(object):
 
         return record
 
-    def serialize(self, where: Path, prefix='') -> dict:
+    def serialize(self, where: Path, prefix='roiset') -> dict:
         """
         Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks
         :param where: path of directory in which to write files
@@ -932,18 +933,20 @@ class RoiSet(object):
         :return: nested dict of Path objects describing the locations of export products
         """
         record = {}
-        df_exp = self.export_patch_masks(
-            where / 'tight_patch_masks',
-            prefix=prefix,
-            pad_to=None,
-            expanded=False
-        )
-        # record patch masks paths to dataframe, then save static columns to CSV
-        se_pa = df_exp.patch_mask_path.apply(
-            lambda x: str(Path('tight_patch_masks') / x)
-        ).rename('tight_patch_masks_path')
+        if not self._df.binary_mask.apply(lambda x: np.all(x)).all():  # binary masks aren't just all True
+            df_exp = self.export_patch_masks(
+                where / 'tight_patch_masks',
+                prefix=prefix,
+                pad_to=None,
+                expanded=False
+            )
+            # record patch mask paths in the dataframe; static columns are written to CSV after this block
+            se_pa = df_exp.patch_mask_path.apply(
+                lambda x: str(Path('tight_patch_masks') / x)
+            ).rename('tight_patch_masks_path')
+            self._df = self._df.join(se_pa)
+            record['tight_patch_masks'] = list(se_pa)
 
-        self._df = self._df.join(se_pa)
         csv_path = where / 'dataframe' / (prefix + '.csv')
         csv_path.parent.mkdir(parents=True, exist_ok=True)
         self._df.drop(
@@ -952,9 +955,10 @@ class RoiSet(object):
         ).to_csv(csv_path, index=False)
 
         record['dataframe'] = str(Path('dataframe') / csv_path.name)
-        record['tight_patch_masks'] = list(se_pa)
+
         return record
 
+
     def get_polygons(self, poly_threshold=0, dilation_radius=1) -> pd.DataFrame:
         self.coordinates_ = """
         Fit polygons to all object boundaries in the RoiSet
@@ -994,9 +998,8 @@ class RoiSet(object):
     def acc_obj_ids(self):
         return make_object_ids_from_df(self._df, self.acc_raw.shape_dict)
 
-    # TODO: make this work with obj det dataset
     @staticmethod
-    def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='') -> Self:
+    def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self:
         """
         Create an RoiSet object from saved files and an image accessor
         :param acc_raw: accessor to image that contains ROIs
@@ -1006,20 +1009,34 @@ class RoiSet(object):
         :return: RoiSet object
         """
         df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']]
-
-        def _read_binary_mask(r):
-            ext = 'png'
-            fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}'
-            try:
-                ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname)
-                assert ma_acc.chroma == 1 and ma_acc.nz == 1
-                mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max
-                return mask_data
-            except Exception as e:
-                raise DeserializeRoiSet(e)
-        df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
-        id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
-        return RoiSet.from_object_ids(acc_raw, id_mask)
+        pa_masks = where / 'tight_patch_masks'
+
+        if pa_masks.exists():  # import segmentation masks
+            def _read_binary_mask(r):
+                ext = 'png'
+                fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}'
+                try:
+                    ma_acc = generate_file_accessor(pa_masks / fname)
+                    assert ma_acc.chroma == 1 and ma_acc.nz == 1
+                    mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max
+                    return mask_data
+                except Exception as e:
+                    raise DeserializeRoiSet(e)
+
+            df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
+            id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
+            return RoiSet.from_object_ids(acc_raw, id_mask)
+
+        else:  # assume bounding boxes only
+            df['y'] = df['y0']
+            df['x'] = df['x0']
+            df['h'] = df['y1'] - df['y0']
+            df['w'] = df['x1'] - df['x0']
+            return RoiSet.from_bounding_boxes(
+                acc_raw,
+                df[['y', 'x', 'h', 'w']].to_dict(orient='records'),
+                list(df['zi'])
+            )
 
 
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 9de818ee..93e96df7 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -637,7 +637,6 @@ class TestRoiSetObjectDetection(unittest.TestCase):
     def test_create_roiset_from_bounding_boxes(self):
         from skimage.measure import label, regionprops, regionprops_table
 
-
         mask = self.seg_mask_3d
         labels = label(mask.data_xyz, connectivity=3)
         table = pd.DataFrame(
@@ -651,12 +650,12 @@ class TestRoiSetObjectDetection(unittest.TestCase):
         table['h'] = table['y1'] - table['y']
         bboxes = table[['y', 'x', 'h', 'w']].to_dict(orient='records')
 
-
         roiset_bbox = RoiSet.from_bounding_boxes(self.stack_ch_pa, bboxes)
-
+        self.assertTrue('label' in roiset_bbox.get_df().columns)
         patches_bbox = roiset_bbox.get_patches_acc()
         self.assertEqual(len(table), patches_bbox.count)
 
+
         # roiset w/ seg for comparison
         roiset_seg = RoiSet.from_binary_mask(self.stack_ch_pa, mask, allow_3d=True)
         patches_seg = roiset_seg.get_patches_acc()
@@ -666,6 +665,19 @@ class TestRoiSetObjectDetection(unittest.TestCase):
         for i in range(0, roiset_seg.count):
             self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
 
+        # test that serialization does not write patch masks
+        roiset_ser_path = output_path / 'roiset_from_bbox'
+        dd = roiset_bbox.serialize(roiset_ser_path)
+        self.assertTrue('tight_patch_masks' not in dd.keys())
+        self.assertFalse((roiset_ser_path / 'tight_patch_masks').exists())
+
+        # test that deserialized RoiSet matches the original
+        roiset_des = RoiSet.deserialize(self.stack_ch_pa, roiset_ser_path)
+        self.assertEqual(roiset_des.count, roiset_bbox.count)
+        for i in range(0, roiset_des.count):
+            self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
+        self.assertTrue((roiset_bbox.get_zmask() == roiset_des.get_zmask()).all())
+
 
 class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
-- 
GitLab