diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 6afa4b1ab7a4260034569b5e15f6075433ca2ab8..f0af11a76ed45136ce5382b09bab315e972fd19c 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -172,8 +172,15 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram df = _df_insert_slices(df, acc_raw.shape_dict, expand_box_by) # TODO: make this contingent on whether seg is included + def _make_binary_mask(r): + acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) + acc.get_mono(0, mip=True) + cropped = acc.crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data + return cropped + df['binary_mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], + # lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], + _make_binary_mask, axis=1, result_type='reduce', ) @@ -327,7 +334,7 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs :return: object identities map """ - return __class__.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) + return RoiSet.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) # TODO: generate overlapping RoiSet from multiple masks @@ -519,13 +526,14 @@ class RoiSet(object): patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255 else: patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8') - patch[:, :, 0, 0] = roi.binary_mask * 255 + patch = roi.binary_mask * 255 if pad_to: patch = pad(patch, pad_to) return patch dfe = self._df.copy() + # TODO: can just pass function handle dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1) return dfe