diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 60cb59129cb47981a05f759ab1b6f7dd18e9ffac..f529da223bf2f7303216677dc0ab95ecd6ec0c6d 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -221,6 +221,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame """ # build dataframe of objects, assign z index to each object + # TODO: don't assume that channel 0 is the basis of z-argmax if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image df = pd.DataFrame(regionprops_table( acc_obj_ids.data_xy, @@ -328,7 +329,6 @@ def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: class RoiSet(object): - # TODO: __init__ to take bounding boxes e.g. from obj det model; flag if overlaps are allowed def __init__( self, acc_raw: GenericImageDataAccessor, @@ -383,17 +383,40 @@ class RoiSet(object): @staticmethod def from_bounding_boxes( acc_raw: GenericImageDataAccessor, - yxhw_list: List, + bbox_df: pd.DataFrame, params: RoiSetMetaParams = RoiSetMetaParams() ): - df = pd.DataFrame([ - { - 'y0': yxhw[0], - 'y1': yxhw[0] + yxhw[2], - 'x0': yxhw[1], - 'x1': yxhw[1] + yxhw[3], - } for yxhw in yxhw_list - ]) + + # deproject if zi is not provided + if 'zi' not in bbox_df.columns: + def _slice_zmax(r): + acc_raw.crop_hw( + r.y0, + r.x0, + r.y1 - r.y0, + r.x1 - r.x0 + ) + zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') + bbox_df['zi'] = 0 + + df = df_insert_slices( + bbox_df, + acc_raw.shape_dict, + params.get('expand_box_by', 0) + ) + + def _make_binary_mask(r): + # TODO: make square mask array + # acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) + # cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy + return cropped + + df['binary_mask'] = df.apply( + _make_binary_mask, + axis=1, + result_type='reduce', + ) + return RoiSet(acc_raw, df, params) @@ -945,14 +968,14 @@ class RoiSet(object): # TODO: make this work with obj det dataset @staticmethod - def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): + def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='') -> Self: """ Create an RoiSet object from saved files and an image accessor :param acc_raw: accessor to image that contains ROIs :param where: path to directory containing RoiSet serialization files, namely dataframe.csv and a subdirectory named tight_patch_masks :param prefix: starting prefix of patch mask filenames - :return: + :return: RoiSet object """ df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 75a9b7752984998f081fd77ecbf8589efb9304e3..6794dac20818a2b5bb9d534ccc50ec1c07353593 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -623,6 +623,32 @@ class TestRoiSetSerialization(unittest.TestCase): t_acc = generate_file_accessor(pt) self.assertTrue(np.all(r_acc.data == t_acc.data)) + def test_create_roiset_from_bounding_boxes(self): + from skimage.measure import label, regionprops, regionprops_table + + + mask = self.seg_mask_3d.data_xyz + labels = label(mask) + table = pd.DataFrame( + regionprops_table(labels) + ).rename( + columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'zi', 'bbox-3': 'y1', 'bbox-4': 'x1'} + ).drop( + columns=['bbox-5'] + ) + table['w'] = table['x1'] - table['x0'] + table['h'] = table['y1'] - table['y0'] + + + # bbox = + self.assertTrue(False) + # RoiSet.from_bounding_boxes( + # self.stack_ch_pa, + # + # ) + + # test segments reside in bounding boxes + class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_compute_polygons(self):