From 780e01787179c4515c1e7f1857136cb8bf4545cd Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 20 Sep 2024 10:15:26 +0200
Subject: [PATCH] Roughed in more bounding box generator

---
 model_server/base/roiset.py | 19 ++++++++++++++-----
 tests/base/test_roiset.py   |  8 ++++++++
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index f624460f..9f7de716 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1,7 +1,7 @@
 import itertools
 from math import sqrt, floor
 from pathlib import Path
-from typing import List, Union
+from typing import Dict, List, Union
 from typing_extensions import Self
 from uuid import uuid4
 
@@ -386,12 +386,17 @@ class RoiSet(object):
     @staticmethod
     def from_bounding_boxes(
         acc_raw: GenericImageDataAccessor,
-        bbox_df: pd.DataFrame,
+        bbox_yxhw: List[Dict],
+        bbox_zi: Union[List[int], int] = None,
         params: RoiSetMetaParams = RoiSetMetaParams()
     ):
+        bbox_df = pd.DataFrame(bbox_yxhw)
+        if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']:
+            raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}')
 
-        # deproject if zi is not provided
-        if 'zi' not in bbox_df.columns:
+
+        # deproject if zi is not specified
+        if bbox_zi is None:
             def _slice_zmax(r):
                 acc_raw.crop_hw(
                     r.y0,
@@ -400,7 +405,8 @@ class RoiSet(object):
                     r.x1 - r.x0
                 )
                 zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
-            bbox_df['zi'] = 0
+        else:
+            bbox_df['zi'] = bbox_zi
 
         df = df_insert_slices(
             bbox_df,
@@ -1010,6 +1016,9 @@ class RoiSet(object):
 class Error(Exception):
     pass
 
+class BoundingBoxError(Error):
+    pass
+
 class DeserializeRoiSet(Error):
     pass
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index afc75f35..ba553564 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -626,6 +626,14 @@ class TestRoiSetSerialization(unittest.TestCase):
         self.assertTrue(ref_roiset.contains_segmentation)
         self.assertTrue(ref_roiset.contains_segmentation)
 
+class TestRoiSetObjectDetection(unittest.TestCase):
+
+    def setUp(self) -> None:
+        # set up test raw data and segmentation from file
+        self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
+        self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel'])
+        self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path'])
+
     def test_create_roiset_from_bounding_boxes(self):
         from skimage.measure import label, regionprops, regionprops_table
 
-- 
GitLab