From 3730f8c72c5cbecd6e87a3861a355c007a46ed6e Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 20 Sep 2024 11:35:06 +0200
Subject: [PATCH] Parameterized deprojection when creating RoiSet from binary
 mask, including option to pass a channel

---
 model_server/base/roiset.py | 39 ++++++++++++++++++++++++++++---------
 tests/base/test_roiset.py   |  8 +++++---
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 9f7de716..ed8f44f8 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -52,6 +52,7 @@ class RoiFilter(BaseModel):
 class RoiSetMetaParams(BaseModel):
     filters: Union[RoiFilter, None] = None
     expand_box_by: List[int] = [128, 0]
+    deproject_channel: Union[int, None] = None
 
 
 class RoiSetExportParams(BaseModel):
@@ -210,23 +211,33 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat
     return dfbb
 
 
-def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
+def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_channel=None) -> pd.DataFrame:
     """
-    Build dataframe associate object IDs with summary stats
+    Build dataframe that associate object IDs with summary stats;
     :param acc_raw: accessor to raw image data
     :param acc_obj_ids: accessor to map of object IDs
     :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
-
+    :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel
     :return: pd.DataFrame
     """
     # build dataframe of objects, assign z index to each object
 
-    # TODO: don't assume that channel 0 is the basis of z-argmax
-    # TODO: :param deproject: assign object's z-position based on argmax of raw data if True
-    if acc_obj_ids.nz == 1:  # deproject objects' z-coordinates from argmax of raw image
+    if acc_obj_ids.nz == 1 and acc_raw.nz > 1:
+
+        if deproject_channel is None or deproject_channel >= acc_raw.chroma or deproject_channel < 0:
+            if acc_raw.chroma == 1:
+                deproject_channel = 0
+            else:
+                raise NoDeprojectChannelSpecifiedError(
+                    f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.'
+                )
+        acc_raw.get_mono(deproject_channel)
+
+        zi_map = acc_raw.get_mono(deproject_channel).get_z_argmax().data_xy.astype('uint16')
+        assert len(zi_map.shape) == 2
         df = pd.DataFrame(regionprops_table(
             acc_obj_ids.data_xy,
-            intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
+            intensity_image=zi_map,
             properties=('label', 'area', 'intensity_mean', 'bbox')
         )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'})
         df['zi'] = df['intensity_mean'].round().astype('int')
@@ -238,7 +249,11 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
         )).rename(columns={
             'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
         })
-        df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax())
+
+        def _get_zi_from_label(la):
+            return acc_obj_ids.apply(lambda x: x == la).get_focus_vector().argmax()
+
+        df['zi'] = df['label'].apply(_get_zi_from_label)
 
     df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
 
@@ -374,7 +389,9 @@ class RoiSet(object):
 
         df = filter_df(
             make_df_from_object_ids(
-                acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by
+                acc_raw, acc_obj_ids,
+                expand_box_by=params.expand_box_by,
+                deproject_channel=params.deproject_channel,
             ),
             params.filters,
         )
@@ -404,6 +421,7 @@ class RoiSet(object):
                     r.y1 - r.y0,
                     r.x1 - r.x0
                 )
+                # TODO: use accessor.get_z_argmax
                 zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
         else:
             bbox_df['zi'] = bbox_zi
@@ -1022,6 +1040,9 @@ class BoundingBoxError(Error):
 class DeserializeRoiSet(Error):
     pass
 
+class NoDeprojectChannelSpecifiedError(Error):
+    pass
+
 class DerivedChannelError(Error):
     pass
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index ba553564..497bcb03 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -187,18 +187,18 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         return roiset
 
     def test_classify_by_multiple_channels(self):
-        roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask)
+        roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0))
         roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
         self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
         self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
         return roiset
 
     def test_transfer_classification(self):
-        roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask)
+        roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0))
 
         # prepare alternative mask and compare
         smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=1.5))
-        roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask)
+        roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask, params=RoiSetMetaParams(deproject_channel=0))
         dmask = (self.seg_mask.data / 255) + (smoothed_mask.data / 255)
         self.assertTrue(np.all(np.unique(dmask) == [0, 1, 2]))
         total_iou = (dmask == 2).sum() / ((dmask == 1).sum() + (dmask == 2).sum())
@@ -227,6 +227,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
             self.seg_mask,
             params=RoiSetMetaParams(
                 filters={'area': {'min': 1e3, 'max': 1e4}},
+                deproject_channel=0,
             )
         )
         roiset.classify_by(
@@ -295,6 +296,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
                 expand_box_by=(128, 2),
                 mask_type='boxes',
                 filters={'area': {'min': 1e3, 'max': 1e4}},
+                deproject_channel=0,
             )
         )
 
-- 
GitLab