From 8ecc14be9d9c7008e02635fd403c299aab07b29c Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 16 Jul 2024 19:06:21 +0200 Subject: [PATCH] Generally removing necessity of instantiating with a map of object IDs --- model_server/base/roiset.py | 187 ++++++++++++++++++++++-------------- tests/base/test_roiset.py | 6 +- 2 files changed, 118 insertions(+), 75 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 36f9e066..6afa4b1a 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -108,6 +108,19 @@ def _focus_metrics(): 'moment': lambda x: moment(x.flatten(), moment=2), } + +def _filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: + query_str = 'label > 0' # always true + if filters is not None: # parse filters + for k, val in filters.dict(exclude_unset=True).items(): + assert k in ('area') + vmin = val['min'] + vmax = val['max'] + assert vmin >= 0 + query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' + return df.loc[df.query(query_str).index, :] + + # TODO: get overlapping bounding boxes def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: @@ -128,9 +141,50 @@ def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: sdf['overlaps_with'] = second return sdf +def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: + """ + Build dataframe associate object IDs with summary stats + :param acc_raw: accessor to raw image data + :param acc_obj_ids: accessor to map of object IDs + :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + # :param deproject: assign object's z-position based on argmax of raw data if True + :return: pd.DataFrame + """ + # build dataframe of objects, assign z index to each object + + if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data[:, :, 0, 0], + intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), + properties=('label', 'area', 'intensity_mean', 'bbox') + )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) + df['zi'] = df['intensity_mean'].round().astype('int') + + else: # objects' z-coordinates come from arg of max count in object identities map + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data[:, :, 0, :], + properties=('label', 'area', 'bbox') + )).rename(columns={ + 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' + }) + df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) + + df = _df_insert_slices(df, acc_raw.shape_dict, expand_box_by) + + # TODO: make this contingent on whether seg is included + df['binary_mask'] = df.apply( + lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], + axis=1, + result_type='reduce', + ) + return df + + +def _df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: + h = sd['Y'] + w = sd['X'] + nz = sd['Z'] -def _df_insert_slices(df: pd.DataFrame, shape: tuple, expand_box_by) -> pd.DataFrame: - h, w, c, nz = shape df['h'] = df['y1'] - df['y0'] df['w'] = df['x1'] - df['x0'] ebxy, ebz = expand_box_by @@ -185,6 +239,18 @@ def _safe_add(a, g, b): np.iinfo(a.dtype).max ).astype(a.dtype) +def _make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: + id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16') + if 'binary_mask' not in df.columns: + raise MissingSegmentationError('RoiSet dataframe does not contain segmentation') + + def _label_obj(r): + sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] + id_mask[sl] = id_mask[sl] + r.label * r.binary_mask + + df.apply(_label_obj, axis=1) + return InMemoryDataAccessor(id_mask) + class RoiSet(object): @@ -192,7 +258,8 @@ class RoiSet(object): def __init__( self, acc_raw: GenericImageDataAccessor, - acc_obj_ids: GenericImageDataAccessor, + # acc_obj_ids: GenericImageDataAccessor, + df: pd.DataFrame, params: RoiSetMetaParams = RoiSetMetaParams(), ): """ @@ -203,19 +270,20 @@ class RoiSet(object): labels its membership in a connected object :param params: optional arguments that influence the definition and representation of ROIs """ - assert acc_obj_ids.chroma == 1 - self.acc_obj_ids = acc_obj_ids + # assert acc_obj_ids.chroma == 1 + # self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw self.accs_derived = [] self.params = params - self._df = self.filter_df( - self.make_df_from_object_ids( - self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by - ), - params.filters, - ) + # self._df = self.filter_df( + # self.make_df_from_object_ids( + # self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by + # ), + # params.filters, + # ) + self._df = df self.count = len(self._df) self.object_class_maps = {} # classification results @@ -223,6 +291,24 @@ class RoiSet(object): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') + @staticmethod + def from_object_ids( + acc_raw: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, + params: RoiSetMetaParams = RoiSetMetaParams(), + ): + assert acc_obj_ids.chroma == 1 + + df = _filter_df( + _make_df_from_object_ids( + acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by + ), + params.filters, + ) + + return RoiSet(acc_raw, df, params) + + # TODO: add or overload for object detection case @staticmethod def from_binary_mask( @@ -241,50 +327,16 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs :return: object identities map """ - return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) + return __class__.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) # TODO: generate overlapping RoiSet from multiple masks # call e.g. static adder + @staticmethod - def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: - """ - Build dataframe associate object IDs with summary stats - :param acc_raw: accessor to raw image data - :param acc_obj_ids: accessor to map of object IDs - :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) - # :param deproject: assign object's z-position based on argmax of raw data if True - :return: pd.DataFrame - """ - # build dataframe of objects, assign z index to each object - - if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data[:, :, 0, 0], - intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), - properties=('label', 'area', 'intensity_mean', 'bbox') - )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) - df['zi'] = df['intensity_mean'].round().astype('int') - - else: # objects' z-coordinates come from arg of max count in object identities map - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data[:, :, 0, :], - properties=('label', 'area', 'bbox') - )).rename(columns={ - 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' - }) - df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax()) - - df = _df_insert_slices(df, acc_raw.shape, expand_box_by) - - # TODO: make this contingent on whether seg is included - df['binary_mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], - axis=1, - result_type='reduce', - ) - return df + def make_df_from_patches(): + pass # TODO: get overlapping segments @@ -298,21 +350,6 @@ class RoiSet(object): dfbb['iou'] = dfbb.apply() - - # TODO: test if overlaps exist - - @staticmethod - def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: - query_str = 'label > 0' # always true - if filters is not None: # parse filters - for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area') - vmin = val['min'] - vmax = val['max'] - assert vmin >= 0 - query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' - return df.loc[df.query(query_str).index, :] - def get_df(self) -> pd.DataFrame: return self._df @@ -731,26 +768,29 @@ class RoiSet(object): return {} + @property + def acc_obj_ids(self): + return _make_object_ids_from_df(self._df, self.acc_raw.shape_dict) + # TODO: add docstring # TODO: make this work with obj det dataset @staticmethod def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] - id_mask = np.zeros((*acc_raw.hw, 1, acc_raw.nz), dtype='uint16') - def _label_obj(r): - sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] + def _read_binary_mask(r): ext = 'png' fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' try: ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) - bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max - id_mask[sl] = id_mask[sl] + r.label * bool_mask + assert ma_acc.chroma == 1 and ma_acc.nz == 1 + mask_data = ma_acc.data / np.iinfo(ma_acc.data.dtype).max + return mask_data except Exception as e: raise DeserializeRoiSet(e) - - df.apply(_label_obj, axis=1) - return RoiSet(acc_raw, InMemoryDataAccessor(id_mask)) + df['binary_mask'] = df.apply(_read_binary_mask, axis=1) + id_mask = _make_object_ids_from_df(df, acc_raw.shape_dict) + return RoiSet.from_object_ids(acc_raw, id_mask) def project_stack_from_focal_points( @@ -808,4 +848,7 @@ class DeserializeRoiSet(Error): pass class DerivedChannelError(Error): + pass + +class MissingSegmentationError(Error): pass \ No newline at end of file diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index dcf849bd..8e577f50 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -137,7 +137,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_create_roiset_with_no_objects(self): zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype)) - roiset = RoiSet(self.stack_ch_pa, zero_obmap) + roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap) self.assertEqual(roiset.count, 0) def test_slices_are_valid(self): @@ -592,7 +592,7 @@ class TestRoiSetSerialization(unittest.TestCase): id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) self.assertEqual(self.stack_ch_pa.shape, id_map.shape) - roiset = RoiSet( + roiset = RoiSet.from_object_ids( self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='contours') @@ -605,7 +605,7 @@ class TestRoiSetSerialization(unittest.TestCase): self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3]) self.assertEqual(id_map.nz, 1) - roiset = RoiSet( + roiset = RoiSet.from_object_ids( self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='contours') -- GitLab