diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 9f7de71687db8bdc913201116e2ffe3f4b3e47e3..ed8f44f873d059bb5039b1970558e398ce516e25 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -52,6 +52,7 @@ class RoiFilter(BaseModel): class RoiSetMetaParams(BaseModel): filters: Union[RoiFilter, None] = None expand_box_by: List[int] = [128, 0] + deproject_channel: Union[int, None] = None class RoiSetExportParams(BaseModel): @@ -210,23 +211,33 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat return dfbb -def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: +def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_channel=None) -> pd.DataFrame: """ - Build dataframe associate object IDs with summary stats + Build dataframe that associate object IDs with summary stats; :param acc_raw: accessor to raw image data :param acc_obj_ids: accessor to map of object IDs :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) - + :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel :return: pd.DataFrame """ # build dataframe of objects, assign z index to each object - # TODO: don't assume that channel 0 is the basis of z-argmax - # TODO: :param deproject: assign object's z-position based on argmax of raw data if True - if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image + if acc_obj_ids.nz == 1 and acc_raw.nz > 1: + + if deproject_channel is None or deproject_channel >= acc_raw.chroma or deproject_channel < 0: + if acc_raw.chroma == 1: + deproject_channel = 0 + else: + raise NoDeprojectChannelSpecifiedError( + f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.' + ) + acc_raw.get_mono(deproject_channel) + + zi_map = acc_raw.get_mono(deproject_channel).get_z_argmax().data_xy.astype('uint16') + assert len(zi_map.shape) == 2 df = pd.DataFrame(regionprops_table( acc_obj_ids.data_xy, - intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), + intensity_image=zi_map, properties=('label', 'area', 'intensity_mean', 'bbox') )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) df['zi'] = df['intensity_mean'].round().astype('int') @@ -238,7 +249,11 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame )).rename(columns={ 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' }) - df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) + + def _get_zi_from_label(la): + return acc_obj_ids.apply(lambda x: x == la).get_focus_vector().argmax() + + df['zi'] = df['label'].apply(_get_zi_from_label) df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by) @@ -374,7 +389,9 @@ class RoiSet(object): df = filter_df( make_df_from_object_ids( - acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by + acc_raw, acc_obj_ids, + expand_box_by=params.expand_box_by, + deproject_channel=params.deproject_channel, ), params.filters, ) @@ -404,6 +421,7 @@ class RoiSet(object): r.y1 - r.y0, r.x1 - r.x0 ) + # TODO: use accessor.get_z_argmax zmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') else: bbox_df['zi'] = bbox_zi @@ -1022,6 +1040,9 @@ class BoundingBoxError(Error): class DeserializeRoiSet(Error): pass +class NoDeprojectChannelSpecifiedError(Error): + pass + class DerivedChannelError(Error): pass diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index ba553564dab20555b6f1691d1b71eceb8bc3dd0e..497bcb031b909af68cb5ea3fc11c043794000d84 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -187,18 +187,18 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): return roiset def test_classify_by_multiple_channels(self): - roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask) + roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0)) roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1])) return roiset def test_transfer_classification(self): - roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask) + roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0)) # prepare alternative mask and compare smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=1.5)) - roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask) + roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask, params=RoiSetMetaParams(deproject_channel=0)) dmask = (self.seg_mask.data / 255) + (smoothed_mask.data / 255) self.assertTrue(np.all(np.unique(dmask) == [0, 1, 2])) total_iou = (dmask == 2).sum() / ((dmask == 1).sum() + (dmask == 2).sum()) @@ -227,6 +227,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.seg_mask, params=RoiSetMetaParams( filters={'area': {'min': 1e3, 'max': 1e4}}, + deproject_channel=0, ) ) roiset.classify_by( @@ -295,6 +296,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa expand_box_by=(128, 2), mask_type='boxes', filters={'area': {'min': 1e3, 'max': 1e4}}, + deproject_channel=0, ) )