From ce711cba1966178fa4f9ed19cf34db1ebb07248c Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 3 Feb 2024 21:53:38 +0100 Subject: [PATCH] slice and mask array objects now in RoiSet's dataframe --- model_server/extensions/chaeo/products.py | 3 +- .../extensions/chaeo/tests/test_zstack.py | 8 +-- model_server/extensions/chaeo/zmask.py | 59 +++++++++---------- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index 82699fb2..84535d73 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -61,9 +61,8 @@ def write_patch_to_file(where, fname, yxcz): def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: patches = [] for ob in roiset.get_df().itertuples('Roi'): - sp_sl = roiset.get_rel_slice_at(ob.Index) patch = np.zeros((ob.ebb_h, ob.ebb_w, 1, 1), dtype='uint8') - patch[sp_sl][:, :, 0, 0] = roiset.get_mask_at(ob.Index) * 255 + patch[ob.relative_slice][:, :, 0, 0] = ob.mask * 255 if pad_to: patch = pad(patch, pad_to) diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index cc9a4297..6fca51b8 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -96,20 +96,18 @@ class TestZStackDerivedDataProducts(unittest.TestCase): def test_zmask_slices_are_valid(self): roiset = self.test_zmask_makes_correct_boxes() - slices = [roiset.get_slice_at(i) for i in roiset.get_df().index] - for s in slices: + for s in roiset.get_slices(): ebb = roiset.acc_raw.data[s] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) def test_zmask_rel_slices_are_valid(self): roiset = self.test_zmask_makes_correct_boxes() - slices = [roiset.get_slice_at(i) for i in roiset.get_df().index] - rel_slices = [roiset.get_rel_slice_at(i) for i in roiset.get_df().index] - for i, s in enumerate(slices): + for i, s in enumerate(roiset.get_slices()): ebb = roiset.acc_raw.data[s] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) + rel_slices = roiset.get_df()['relative_slice'] rbb = ebb[rel_slices[i]] self.assertEqual(len(rbb.shape), 4) self.assertTrue(np.all([si >= 1 for si in rbb.shape])) diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 2cef8230..b06c276f 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -36,7 +36,9 @@ class RoiSet(object): self.acc_raw = acc_raw self._df = self.filter_df( - self.make_df(self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by), + self.make_df( + self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by + ), params.filters, ) @@ -78,7 +80,14 @@ class RoiSet(object): self.object_class_map = None @staticmethod - def make_df(acc_raw, acc_obj_ids, expand_box_by): + def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: + """ + Build dataframe associate object IDs with summary stats + :param acc_raw: accessor to raw image data + :param acc_obj_ids: accessor to map of object IDs + :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + :return: pd.DataFrame + """ # build dataframe of objects, assign z index to each object argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16') df = ( @@ -117,36 +126,22 @@ class RoiSet(object): assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) assert np.all(df['rel_y1'] <= (df['ebb_x1'] - df['ebb_x0'])) + df['slice'] = df.apply( + lambda r: + np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], + axis=1 + ) + df['relative_slice'] = df.apply( + lambda r: + np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], + axis=1 + ) + df['mask'] = df.apply( + lambda r: (acc_obj_ids == r.label)[r.y0: r.y1, r.x0: r.x1], + axis=1 + ) return df - def get_slice_at(self, idx) -> tuple: - """ - Return slice object in np.s_ format that defines expanded bounding box of object - :param idx: object index (Index in DataFrame, does not necessarily start at zero) - :return: slice object - """ - ob = self.get_df().loc[idx, :].astype('int64') - return np.s_[ob.ebb_y0: ob.ebb_y1, ob.ebb_x0: ob.ebb_x1, :, ob.ebb_z0: ob.ebb_z1 + 1] - - def get_rel_slice_at(self, idx) -> tuple: - """ - Return slice object in np.s_ format that defines bounding box of an object within its expanded bounding box - :param idx: object index (Index in DataFrame, does not necessarily start at zero) - :return: slice object - """ - ob = self.get_df().loc[idx, :].astype('int64') - return np.s_[ob.rel_y0: ob.rel_y1, ob.rel_x0: ob.rel_x1, :, :] - - - def get_mask_at(self, idx) -> np.ndarray: - """ - Return 2D array describing object mask that fills (unexpanded) bounding box at index idx - :param idx: object index (Index in DataFrame, does not necessarily start at zero) - :return: np.ndarray boolean mask - """ - ob = self.get_df().loc[idx, :].astype('int64') - obmask = (self.acc_obj_ids == ob.label) - return obmask[ob.y0: ob.y1, ob.x0: ob.x1] @staticmethod def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: @@ -158,12 +153,14 @@ class RoiSet(object): vmax = val['max'] assert vmin >= 0 query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' - # df.loc[df.query(query_str).index, 'keeper'] = True return df.loc[df.query(query_str).index, :] def get_df(self) -> pd.DataFrame: # TODO: exclude columns that refer to objects return self._df + def get_slices(self) -> pd.Series: + return self.get_df()['slice'] + def add_df_col(self, name, se: pd.Series) -> None: self._df[name] = se -- GitLab