From 3b72cb37b65c6498b686380b83e5b649df1fcc63 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 27 Nov 2024 21:18:59 +0100
Subject: [PATCH] Slices are 3d

---
 model_server/base/roiset.py | 20 ++++++++++++++------
 tests/base/test_roiset.py   | 11 ++++++++---
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index caf27aae..9c4abafc 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -372,12 +372,20 @@ def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
     assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0']))
     assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0']))
 
-    df['slice'] = df.apply(
-        lambda r:
-        np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)],
-        axis=1,
-        result_type='reduce',
-    )
+    if is_df_3d(df):
+        df['slice'] = df.apply(
+            lambda r:
+            np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.z0): int(r.z1)],
+            axis=1,
+            result_type='reduce',
+        )
+    else:
+        df['slice'] = df.apply(
+            lambda r:
+            np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)],
+            axis=1,
+            result_type='reduce',
+        )
     df['expanded_slice'] = df.apply(
         lambda r:
         np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1],
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 5dc1fd98..b1f8ea68 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -273,7 +273,6 @@ class TestRoiSet3dProducts(unittest.TestCase):
         self.assertGreater(len(roiset.get_df()['zi'].unique()), 1)
         self.assertTrue((df['z1'] - df['z0'] > 1).any())
 
-        # TODO: no labels are actually extending into 3D
         roiset.acc_obj_ids.write(self.where / 'labels.tif')
         return roiset
 
@@ -767,7 +766,10 @@ class TestRoiSetObjectDetection(unittest.TestCase):
         # test bounding box dimensions match those from RoiSet generated directly from segmentation
         self.assertEqual(roiset_seg.count, roiset_bbox.count)
         for i in range(0, roiset_seg.count):
-            self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
+            patch_from_seg = patches_seg.iat(0, crop=True)
+            patch_from_bbox = patches_bbox.iat(0, crop=True)
+            self.assertEqual(patch_from_seg.hw, patch_from_bbox.hw)
+            self.assertEqual(patch_from_seg.chroma, patch_from_bbox.chroma)
 
         # test that serialization does not write patch masks
         roiset_ser_path = output_path / 'roiset_from_bbox'
@@ -779,7 +781,10 @@ class TestRoiSetObjectDetection(unittest.TestCase):
         roiset_des = RoiSet.deserialize(self.stack_ch_pa, roiset_ser_path)
         self.assertEqual(roiset_des.count, roiset_bbox.count)
         for i in range(0, roiset_des.count):
-            self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
+            patch_from_seg = patches_seg.iat(0, crop=True)
+            patch_from_bbox = patches_bbox.iat(0, crop=True)
+            self.assertEqual(patch_from_seg.hw, patch_from_bbox.hw)
+            self.assertEqual(patch_from_seg.chroma, patch_from_bbox.chroma)
         self.assertTrue((roiset_bbox.get_zmask() == roiset_des.get_zmask()).all())
 
 
-- 
GitLab