diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index f624460f1fed5d993e93f4a8427303eb3f95ab3e..9f7de71687db8bdc913201116e2ffe3f4b3e47e3 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,7 +1,7 @@ import itertools from math import sqrt, floor from pathlib import Path -from typing import List, Union +from typing import Dict, List, Union from typing_extensions import Self from uuid import uuid4 @@ -386,12 +386,17 @@ class RoiSet(object): @staticmethod def from_bounding_boxes( acc_raw: GenericImageDataAccessor, - bbox_df: pd.DataFrame, + bbox_yxhw: List[Dict], + bbox_zi: Union[List[int], int] = None, params: RoiSetMetaParams = RoiSetMetaParams() ): + bbox_df = pd.DataFrame(bbox_yxhw) + if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']: + raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}') - # deproject if zi is not provided - if 'zi' not in bbox_df.columns: + + # deproject if zi is not specified + if bbox_zi is None: def _slice_zmax(r): acc_raw.crop_hw( r.y0, @@ -400,7 +405,8 @@ class RoiSet(object): r.x1 - r.x0 ) zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') - bbox_df['zi'] = 0 + else: + bbox_df['zi'] = bbox_zi df = df_insert_slices( bbox_df, @@ -1010,6 +1016,9 @@ class RoiSet(object): class Error(Exception): pass +class BoundingBoxError(Error): + pass + class DeserializeRoiSet(Error): pass diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index afc75f35436f360f3d51e666a805638773d02269..ba553564dab20555b6f1691d1b71eceb8bc3dd0e 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -626,6 +626,14 @@ class TestRoiSetSerialization(unittest.TestCase): self.assertTrue(ref_roiset.contains_segmentation) self.assertTrue(ref_roiset.contains_segmentation) +class TestRoiSetObjectDetection(unittest.TestCase): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) + self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel']) + self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path']) + def test_create_roiset_from_bounding_boxes(self): from skimage.measure import label, regionprops, regionprops_table