From 009e4b309d610ec521d84b7acbe12e3c80c0c54d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 27 Nov 2024 20:47:02 +0100
Subject: [PATCH] All pre-existing tests pass

---
 model_server/base/roiset.py | 51 ++++++++++++++++---------------------
 tests/base/test_roiset.py   |  2 ++
 2 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index dbe24304..ba35a9a9 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -234,6 +234,10 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat
     return dfbb
 
 
+def is_df_3d(df: pd.DataFrame) -> bool:  # True only when df has both 'z0' and 'zi' columns
+    return 'z0' in df.columns and 'zi' in df.columns
+
+
 def make_df_from_object_ids(
         acc_raw,
         acc_obj_ids,
@@ -290,17 +294,15 @@ def make_df_from_object_ids(
 
     elif acc_obj_ids.nz == 1 and acc_raw.nz == 1:  # purely 2d, no z information in dataframe
         df = pd.DataFrame(regionprops_table(
-            acc_obj_ids.data_yxz,
+            acc_obj_ids.data_yx,
             properties=('label', 'area', 'bbox')
         )).rename(columns={
             'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'
         })
 
-        # df['zi'] = 0
-
         def _make_binary_mask(r):
             acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
-            return acc.get_mono(0).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yx
+            cropped = acc.get_mono(0).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yx
             return cropped
 
     else:  # purely 3d: objects' median z-coordinates come from arg of max count in object identities map
@@ -318,24 +320,11 @@ def make_df_from_object_ids(
 
         def _make_binary_mask(r):
             acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
-
-            # TODO: optional r.z0:z1 + 1 for case of 3d objects
             cropped = acc.get_mono(0).crop_hwd((r.y0, r.x0, r.z0, (r.y1 - r.y0), (r.x1 - r.x0), (r.z1 - r.z0))).data_yxz
-            # TODO: 3d crop method on accessor
             return cropped
 
     df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
-
     df_fil = filter_df(df, filters)
-
-    # def _make_binary_mask(r):
-    #     acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
-    #
-    #     # TODO: optional r.z0:z1 + 1 for case of 3d objects
-    #     # cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yx
-    #     cropped = acc.get_mono(0).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_yxz[r.z0: r.z1]
-    #     return cropped
-
     df_fil['binary_mask'] = df_fil.apply(
         _make_binary_mask,
         axis=1,
@@ -360,16 +349,15 @@ def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
     df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0))
     df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w))
 
-    # handle based on third dimension
-    if 'z0' in df.columns and 'z1' in df.columns:
+    # handle based on whether bounding box coordinates are 2d or 3d
+    if is_df_3d(df):
         df['ebb_z0'] = (df.z0 - ebz).apply(lambda x: max(x, 0))
         df['ebb_z1'] = (df.z1 + ebz).apply(lambda x: max(x, nz))
-    elif 'zi' in df.columns:
+    else:
+        if 'zi' not in df.columns:
+            df['zi'] = 0
         df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0))
         df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz))
-    else:
-        df['ebb_z0'] = df.zi
-        df['ebb_z1'] = df.zi
 
     df['ebb_h'] = df['ebb_y1'] - df['ebb_y0']
     df['ebb_w'] = df['ebb_x1'] - df['ebb_x0']
@@ -423,7 +411,7 @@ def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
     if 'binary_mask' not in df.columns:
         raise MissingSegmentationError('RoiSet dataframe does not contain segmentation')
 
-    if 'z0' in df.columns and 'zi' in df.columns:  # use 3d coordinates
+    if is_df_3d(df):  # use 3d coordinates
         def _label_obj(r):
             sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.z0:r.z1]
             mask = np.expand_dims(r.binary_mask, 2)
@@ -837,6 +825,7 @@ class RoiSet(object):
         return patches_df
 
     def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame:
+        is_3d = is_df_3d(self._df)
         def _make_patch_mask(roi):
             if expanded:
                 patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
@@ -845,7 +834,10 @@ class RoiSet(object):
                 patch = (roi.binary_mask * 255).astype('uint8')
             if pad_to:
                 patch = pad(patch, pad_to)
-            return np.expand_dims(patch, (2, 3))
+            if is_3d:
+                return np.expand_dims(patch, 3)
+            else:
+                return np.expand_dims(patch, (2, 3))
 
         dfe = self._df.copy()
         dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1)
@@ -1203,9 +1195,10 @@ class RoiSet(object):
                 fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}'
                 try:
                     ma_acc = generate_file_accessor(pa_masks / fname)
-                    # TODO: check this on re-serialized test data
-                    # assert ma_acc.chroma == 1 and ma_acc.nz == 1
-                    mask_data = ma_acc.data_yxz / ma_acc.dtype_max
+                    if is_df_3d(df):
+                        mask_data = ma_acc.data_yxz / ma_acc.dtype_max
+                    else:
+                        mask_data = ma_acc.data_yx / ma_acc.dtype_max
                     return mask_data
                 except Exception as e:
                     raise DeserializeRoiSetError(e)
@@ -1214,7 +1207,7 @@ class RoiSet(object):
             id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
             return cls.from_object_ids(acc_raw, id_mask)
 
-        else:  # assume bounding boxes only
+        else:  # assume bounding boxes, exclusively 2d objects
             df['y'] = df['y0']
             df['x'] = df['x0']
             df['h'] = df['y1'] - df['y0']
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index dd792e7c..46896538 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -252,6 +252,8 @@ class TestRoiSet3dProducts(unittest.TestCase):
 
     where = output_path / 'run_exports_mono_3d'
 
+    # TODO: test serialization/deserialization of 3d patches
+
     def setUp(self) -> None:
         # set up test raw data and segmentation from file
         self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
-- 
GitLab