From 8254cad584be34f1aa48158e5aa2529388e8a95a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 28 Mar 2024 18:37:26 +0100
Subject: [PATCH] Implementing alternative dataframe generation where
 z-positions are in object map, i.e. do not need to be de-projected from raw

---
 model_server/base/roiset.py | 41 ++++++++++++++++++++++++-----------
 tests/test_roiset.py        | 43 +++++++++++++++++++++++++------------
 2 files changed, 58 insertions(+), 26 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 236d742a..4607b72b 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -154,29 +154,46 @@ class RoiSet(object):
         return self._df.itertuples(name='Roi')
 
     @staticmethod
-    def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
+    def make_df(acc_raw, acc_obj_ids, expand_box_by, deproject=True) -> pd.DataFrame:
         """
         Build dataframe associate object IDs with summary stats
         :param acc_raw: accessor to raw image data
         :param acc_obj_ids: accessor to map of object IDs
         :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
+        :param deproject: assign object's z-position based on argmax of raw data if True
         :return: pd.DataFrame
         """
         # build dataframe of objects, assign z index to each object
-        argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
-        df = (
-            pd.DataFrame(
-                regionprops_table(
-                    acc_obj_ids.data[:, :, 0, 0],
-                    intensity_image=argmax,
-                    properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
+        if deproject:
+            assert acc_obj_ids.nz == 1, 'Can only deproject a 2D object identity map'
+            argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
+            df = (
+                pd.DataFrame(
+                    regionprops_table(
+                        acc_obj_ids.data[:, :, 0, 0],
+                        intensity_image=argmax,
+                        properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
+                    )
+                )
+                .rename(
+                    columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', }
                 )
             )
-            .rename(
-                columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', }
+            df['zi'] = df['intensity_mean'].round().astype('int')
+        else:
+            df = (
+                pd.DataFrame(
+                    regionprops_table(
+                        acc_obj_ids.data[:, :, 0, :],
+                        properties=('label', 'area', 'solidity', 'bbox', 'centroid')
+                    )
+                )
+                .rename(
+                    columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', }
+                )
             )
-        )
-        df['zi'] = df['intensity_mean'].round().astype('int')
+            df['zi'] = ... # from somewhere in obmap floor of avg z for each label...
+
 
         # compute expanded bounding boxes
         h, w, c, nz = acc_raw.shape
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index 532218b9..f8e2eb6c 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -261,20 +261,6 @@ class TestRoiSetFromZmask(unittest.TestCase):
         self.stack_ch_pa = self.stack.get_one_channel_data(roiset_test_data['pipeline_params']['segmentation_channel'])
         self.seg_mask_3d = generate_file_accessor(roiset_test_data['multichannel_zstack']['mask_path_3d'])
 
-        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True)
-        self.assertGreater(id_map.nz, 1)
-
-        roiset = RoiSet.from_3d_obj_ids(
-            self.stack_ch_pa,
-            id_map,
-            params=RoiSetMetaParams(
-                mask_type='contours',
-                filters={'area': {'min': 1e3, 'max': 1e4}},
-            )
-        )
-        self.roiset = roiset
-        self.zmask = InMemoryDataAccessor(roiset.get_zmask())
-
     @staticmethod
     def _label_is_2d(id_map, la):  # single label's zmask has same counts as its MIP
         mask_3d = (id_map == la)
@@ -293,3 +279,32 @@ class TestRoiSetFromZmask(unittest.TestCase):
         is_2d = all([self._label_is_2d(id_map.data, la) for la in labels])
         self.assertTrue(is_2d)
 
+    def test_3d_zmask(self):
+        write_accessor_data_to_file(output_path / 'roiset_from_3d' / 'raw.tif', self.roiset.acc_raw)
+        # write_accessor_data_to_file(output_path / 'roiset_from_3d' / 'ob_ids.tif', self.roiset.acc_obj_ids)
+        write_accessor_data_to_file(output_path / 'roiset_from_3d' / 'zmask.tif', self.zmask)
+
+    def test_create_roiset_from_3d_obj_ids(self):
+        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
+        self.assertEqual(self.stack_ch_pa.shape, id_map.shape)
+
+        roiset = RoiSet.from_3d_obj_ids(
+            self.stack_ch_pa,
+            id_map,
+            params=RoiSetMetaParams(mask_type='contours')
+        )
+        self.assertEqual(roiset.count, id_map.data.max())
+        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
+
+    def test_create_roiset_from_2d_obj_ids(self):
+        id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False)
+        self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
+        self.assertEqual(id_map.nz, 1)
+
+        roiset = RoiSet.from_2d_obj_ids(
+            self.stack_ch_pa,
+            id_map,
+            params=RoiSetMetaParams(mask_type='contours')
+        )
+        self.assertEqual(roiset.count, id_map.data.max())
+        self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
-- 
GitLab