From 8e56b0e634c0f8681dde2c85a5afa96d543960dc Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 16 Jul 2024 20:58:44 +0200 Subject: [PATCH] Need to standardize whether binary_mask is (Y, X, 1, 1) or (Y, X) --- model_server/base/roiset.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 6afa4b1a..f0af11a7 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -172,8 +172,15 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram df = _df_insert_slices(df, acc_raw.shape_dict, expand_box_by) # TODO: make this contingent on whether seg is included + def _make_binary_mask(r): + acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) + acc.get_mono(0, mip=True) + cropped = acc.crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data + return cropped + df['binary_mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], + # lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], + _make_binary_mask, axis=1, result_type='reduce', ) @@ -327,7 +334,7 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs :return: object identities map """ - return __class__.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) + return RoiSet.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) # TODO: generate overlapping RoiSet from multiple masks @@ -519,13 +526,14 @@ class RoiSet(object): patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255 else: patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8') - patch[:, :, 0, 0] = roi.binary_mask * 255 + patch = roi.binary_mask * 255 if pad_to: patch = pad(patch, pad_to) return patch dfe = self._df.copy() + # TODO: can just pass function handle dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1) return dfe -- GitLab