From 8f33b530a85f23a6fab71f3f055a2995504d987b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 13 Sep 2024 16:41:47 +0200
Subject: [PATCH] Still implementing generation of RoiSet from bounding boxes
 alone

---
 model_server/base/roiset.py | 47 +++++++++++++++++++++++++++----------
 tests/base/test_roiset.py   | 26 ++++++++++++++++++++
 2 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 60cb5912..f529da22 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -221,6 +221,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
     """
     # build dataframe of objects, assign z index to each object
 
+    # TODO: don't assume that channel 0 is the basis of z-argmax
     if acc_obj_ids.nz == 1:  # deproject objects' z-coordinates from argmax of raw image
         df = pd.DataFrame(regionprops_table(
             acc_obj_ids.data_xy,
@@ -328,7 +329,6 @@ def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
 
 class RoiSet(object):
 
-    # TODO: __init__ to take bounding boxes e.g. from obj det model; flag if overlaps are allowed
     def __init__(
             self,
             acc_raw: GenericImageDataAccessor,
@@ -383,17 +383,40 @@ class RoiSet(object):
     @staticmethod
     def from_bounding_boxes(
         acc_raw: GenericImageDataAccessor,
-        yxhw_list: List,
+        bbox_df: pd.DataFrame,
         params: RoiSetMetaParams = RoiSetMetaParams()
     ):
-        df = pd.DataFrame([
-            {
-                'y0': yxhw[0],
-                'y1': yxhw[0] + yxhw[2],
-                'x0': yxhw[1],
-                'x1': yxhw[1] + yxhw[3],
-            } for yxhw in yxhw_list
-        ])
+
+        # deproject if zi is not provided
+        if 'zi' not in bbox_df.columns:
+            def _slice_zmax(r):
+                acc_raw.crop_hw(
+                    r.y0,
+                    r.x0,
+                    r.y1 - r.y0,
+                    r.x1 - r.x0
+                )
+                zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
+            bbox_df['zi'] = 0
+
+        df = df_insert_slices(
+            bbox_df,
+            acc_raw.shape_dict,
+            params.get('expand_box_by', 0)
+        )
+
+        def _make_binary_mask(r):
+            # TODO: make square mask array
+            # acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
+            # cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy
+            return cropped
+
+        df['binary_mask'] = df.apply(
+            _make_binary_mask,
+            axis=1,
+            result_type='reduce',
+        )
+
         return RoiSet(acc_raw, df, params)
 
 
@@ -945,14 +968,14 @@ class RoiSet(object):
 
     # TODO: make this work with obj det dataset
     @staticmethod
-    def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''):
+    def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='') -> Self:
         """
         Create an RoiSet object from saved files and an image accessor
         :param acc_raw: accessor to image that contains ROIs
         :param where: path to directory containing RoiSet serialization files, namely dataframe.csv and a subdirectory
             named tight_patch_masks
         :param prefix: starting prefix of patch mask filenames
-        :return:
+        :return: RoiSet object
         """
         df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']]
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 75a9b775..6794dac2 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -623,6 +623,32 @@ class TestRoiSetSerialization(unittest.TestCase):
             t_acc = generate_file_accessor(pt)
             self.assertTrue(np.all(r_acc.data == t_acc.data))
 
+    def test_create_roiset_from_bounding_boxes(self):
+        from skimage.measure import label, regionprops, regionprops_table
+
+
+        mask = self.seg_mask_3d.data_xyz
+        labels = label(mask)
+        table = pd.DataFrame(
+            regionprops_table(labels)
+        ).rename(
+            columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'zi', 'bbox-3': 'y1', 'bbox-4': 'x1'}
+        ).drop(
+            columns=['bbox-5']
+        )
+        table['w'] = table['x1'] - table['x0']
+        table['h'] = table['y1'] - table['y0']
+
+
+        # bbox =
+        self.assertTrue(False)
+        # RoiSet.from_bounding_boxes(
+        #     self.stack_ch_pa,
+        #
+        # )
+
+        # test segments reside in bounding boxes
+
 class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_compute_polygons(self):
-- 
GitLab