From 546dbc1dd6c3613fc78aec95ca45d0f0d4efc808 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 23 Sep 2024 10:14:59 +0200
Subject: [PATCH] Object detection RoiSet generator passes meaningful test

---
 model_server/base/roiset.py |  1 +
 tests/base/test_roiset.py   | 10 +++++-----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 7a95a947..954e878d 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -91,6 +91,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec
             la_3d[:, :, 0, zi] = la_2d
         return InMemoryDataAccessor(la_3d)
     else:
+        # TODO: call argmax z method
         return InMemoryDataAccessor(
             label(
                 acc_seg_mask.data_xyz.max(axis=-1),
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 87a1d889..9de818ee 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -625,8 +625,6 @@ class TestRoiSetSerialization(unittest.TestCase):
             t_acc = generate_file_accessor(pt)
             self.assertTrue(np.all(r_acc.data == t_acc.data))
 
-        self.assertTrue(ref_roiset.contains_segmentation)
-        self.assertTrue(ref_roiset.contains_segmentation)
 
 class TestRoiSetObjectDetection(unittest.TestCase):
 
@@ -660,12 +658,14 @@ class TestRoiSetObjectDetection(unittest.TestCase):
         self.assertEqual(len(table), patches_bbox.count)
 
         # roiset w/ seg for comparison
-        roiset_seg = RoiSet.from_binary_mask(self.stack_ch_pa, mask)
+        roiset_seg = RoiSet.from_binary_mask(self.stack_ch_pa, mask, allow_3d=True)
         patches_seg = roiset_seg.get_patches_acc()
 
-        # TODO: test segments reside in bounding boxes
+        # test bounding box dimensions match those from RoiSet generated directly from segmentation
         self.assertEqual(roiset_seg.count, roiset_bbox.count)
-        self.assertTrue(False)
+        for i in range(0, roiset_seg.count):
+            self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
+
 
 class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
-- 
GitLab