From 97d6b156174ae2e48682987843a3c1d2078c55ae Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 20 Sep 2024 09:56:06 +0200
Subject: [PATCH] Preparing to make binary mask optional in case of object
 detection result

---
 model_server/base/roiset.py | 36 ++++++++++++++++++++++++------------
 tests/base/test_roiset.py   | 16 ++++++++++------
 2 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index f529da22..f624460f 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -216,12 +216,13 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
     :param acc_raw: accessor to raw image data
     :param acc_obj_ids: accessor to map of object IDs
     :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
-    # :param deproject: assign object's z-position based on argmax of raw data if True
+
     :return: pd.DataFrame
     """
     # build dataframe of objects, assign z index to each object
 
     # TODO: don't assume that channel 0 is the basis of z-argmax
+    # TODO: :param deproject: assign object's z-position based on argmax of raw data if True
     if acc_obj_ids.nz == 1:  # deproject objects' z-coordinates from argmax of raw image
         df = pd.DataFrame(regionprops_table(
             acc_obj_ids.data_xy,
@@ -246,6 +247,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
         cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy
         return cropped
 
+
     df['binary_mask'] = df.apply(
         _make_binary_mask,
         axis=1,
@@ -314,6 +316,7 @@ def safe_add(a, g, b):
     ).astype(a.dtype)
 
 def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
+    # TODO: generate rectangular masks if running without segmentation
     id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16')
     if 'binary_mask' not in df.columns:
         raise MissingSegmentationError('RoiSet dataframe does not contain segmentation')
@@ -405,17 +408,20 @@ class RoiSet(object):
             params.get('expand_box_by', 0)
         )
 
-        def _make_binary_mask(r):
-            # TODO: make square mask array
-            # acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
-            # cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy
-            return cropped
-
-        df['binary_mask'] = df.apply(
-            _make_binary_mask,
-            axis=1,
-            result_type='reduce',
-        )
+        # TODO: don't even make binary_mask column in obj det case
+
+        # def _make_binary_mask(r):
+        #     # acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
+        #     # cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy
+        #     return cropped
+        #
+        #
+        #
+        # df['binary_mask'] = df.apply(
+        #     _make_binary_mask,
+        #     axis=1,
+        #     result_type='reduce',
+        # )
 
         return RoiSet(acc_raw, df, params)
 
@@ -639,7 +645,10 @@ class RoiSet(object):
         return InMemoryDataAccessor(om)
 
     def export_dataframe(self, csv_path: Path) -> str:
+        # TODO: move this inside of .serialize()
         csv_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # TODO: suppress errors or check if binary_mask doesn't exist
         self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False)
         return csv_path.name
 
@@ -961,6 +970,9 @@ class RoiSet(object):
 
         return self._df.apply(_poly_from_mask, axis=1)
 
+    @property
+    def contains_segmentation(self):
+        return 'binary_mask' in self._df.columns
 
     @property
     def acc_obj_ids(self):
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 6794dac2..afc75f35 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -623,6 +623,9 @@ class TestRoiSetSerialization(unittest.TestCase):
             t_acc = generate_file_accessor(pt)
             self.assertTrue(np.all(r_acc.data == t_acc.data))
 
+        self.assertTrue(ref_roiset.contains_segmentation)
+        self.assertTrue(ref_roiset.contains_segmentation)
+
     def test_create_roiset_from_bounding_boxes(self):
         from skimage.measure import label, regionprops, regionprops_table
 
@@ -640,15 +643,16 @@ class TestRoiSetSerialization(unittest.TestCase):
         table['h'] = table['y1'] - table['y0']
 
 
-        # bbox =
-        self.assertTrue(False)
-        # RoiSet.from_bounding_boxes(
-        #     self.stack_ch_pa,
-        #
-        # )
+        roiset = RoiSet.from_bounding_boxes(
+            self.stack_ch_pa,
+
+        )
 
         # test segments reside in bounding boxes
 
+        self.assertFalse(roiset.contains_segmentation)
+        self.assertTrue(False)
+
 class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_compute_polygons(self):
-- 
GitLab