From 780e01787179c4515c1e7f1857136cb8bf4545cd Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 20 Sep 2024 10:15:26 +0200 Subject: [PATCH] Roughed in more bounding box generator --- model_server/base/roiset.py | 19 ++++++++++++++----- tests/base/test_roiset.py | 8 ++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index f624460f..9f7de716 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,7 +1,7 @@ import itertools from math import sqrt, floor from pathlib import Path -from typing import List, Union +from typing import Dict, List, Union from typing_extensions import Self from uuid import uuid4 @@ -386,12 +386,17 @@ class RoiSet(object): @staticmethod def from_bounding_boxes( acc_raw: GenericImageDataAccessor, - bbox_df: pd.DataFrame, + bbox_yxhw: List[Dict], + bbox_zi: Union[List[int], int] = None, params: RoiSetMetaParams = RoiSetMetaParams() ): + bbox_df = pd.DataFrame(bbox_yxhw) + if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']: + raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}') - # deproject if zi is not provided - if 'zi' not in bbox_df.columns: + + # deproject if zi is not specified + if bbox_zi is None: def _slice_zmax(r): acc_raw.crop_hw( r.y0, @@ -400,7 +405,8 @@ class RoiSet(object): r.x1 - r.x0 ) zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') - bbox_df['zi'] = 0 + else: + bbox_df['zi'] = bbox_zi df = df_insert_slices( bbox_df, @@ -1010,6 +1016,9 @@ class RoiSet(object): class Error(Exception): pass +class BoundingBoxError(Error): + pass + class DeserializeRoiSet(Error): pass diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index afc75f35..ba553564 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -626,6 +626,14 @@ class TestRoiSetSerialization(unittest.TestCase): self.assertTrue(ref_roiset.contains_segmentation) self.assertTrue(ref_roiset.contains_segmentation) +class TestRoiSetObjectDetection(unittest.TestCase): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) + self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel']) + self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path']) + def test_create_roiset_from_bounding_boxes(self): from skimage.measure import label, regionprops, regionprops_table -- GitLab