From 3b72cb37b65c6498b686380b83e5b649df1fcc63 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 27 Nov 2024 21:18:59 +0100 Subject: [PATCH] Slices are 3d --- model_server/base/roiset.py | 20 ++++++++++++++------ tests/base/test_roiset.py | 11 ++++++++--- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index caf27aae..9c4abafc 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -372,12 +372,20 @@ def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) - df['slice'] = df.apply( - lambda r: - np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], - axis=1, - result_type='reduce', - ) + if is_df_3d(df): + df['slice'] = df.apply( + lambda r: + np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.z0): int(r.z1)], + axis=1, + result_type='reduce', + ) + else: + df['slice'] = df.apply( + lambda r: + np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], + axis=1, + result_type='reduce', + ) df['expanded_slice'] = df.apply( lambda r: np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 5dc1fd98..b1f8ea68 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -273,7 +273,6 @@ class TestRoiSet3dProducts(unittest.TestCase): self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) self.assertTrue((df['z1'] - df['z0'] > 1).any()) - # TODO: no labels are actually extending into 3D roiset.acc_obj_ids.write(self.where / 'labels.tif') return roiset @@ -767,7 +766,10 @@ class TestRoiSetObjectDetection(unittest.TestCase): # test bounding box dimensions match those from RoiSet generated directly from segmentation self.assertEqual(roiset_seg.count, roiset_bbox.count) for i in range(0, roiset_seg.count): - self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape) + patch_from_seg = patches_seg.iat(0, crop=True) + patch_from_bbox = patches_bbox.iat(0, crop=True) + self.assertEqual(patch_from_seg.hw, patch_from_bbox.hw) + self.assertEqual(patch_from_seg.chroma, patch_from_bbox.chroma) # test that serialization does not write patch masks roiset_ser_path = output_path / 'roiset_from_bbox' @@ -779,7 +781,10 @@ class TestRoiSetObjectDetection(unittest.TestCase): roiset_des = RoiSet.deserialize(self.stack_ch_pa, roiset_ser_path) self.assertEqual(roiset_des.count, roiset_bbox.count) for i in range(0, roiset_des.count): - self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape) + patch_from_seg = patches_seg.iat(0, crop=True) + patch_from_bbox = patches_bbox.iat(0, crop=True) + self.assertEqual(patch_from_seg.hw, patch_from_bbox.hw) + self.assertEqual(patch_from_seg.chroma, patch_from_bbox.chroma) self.assertTrue((roiset_bbox.get_zmask() == roiset_des.get_zmask()).all()) -- GitLab