From 8f3d5bc401303f15b4747652c2fefa9160543112 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 29 Nov 2024 15:09:29 +0100 Subject: [PATCH] Deserialization does not solve deprojection a second time; a deprojected RoiSet is serialized as 3D. Tight patch mask format depends on dimensions of ROIs themselves, so 2d PNGs are possible in a 3D set if no ROI is thicker than 1 --- model_server/base/roiset.py | 11 ++++++----- tests/base/test_roiset_pipeline.py | 23 ++++++++++++----------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 509650a2..ccd79c2d 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -320,6 +320,7 @@ def make_df_from_object_ids( return rel_argzmax + r.z0 df['zi'] = df.apply(_get_zi_from_label, axis=1) + df['nz'] = df['z1'] - df['z0'] def _make_binary_mask(r): r = r.convert_dtypes() @@ -809,7 +810,10 @@ class RoiSet(object): def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame: patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy() - ext = 'tif' if is_df_3d(patches_df) else 'png' + if 'nz' in patches_df.columns and any(patches_df['nz'] > 1): + ext = 'tif' + else: + ext = 'png' def _export_patch_mask(roi): patch = InMemoryDataAccessor.from_mono(roi.patch_mask) @@ -1214,10 +1218,7 @@ class RoiSet(object): df['binary_mask'] = df.apply(_read_binary_mask, axis=1) id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) - if not is_3d and id_mask.nz > 1: - return cls.from_object_ids(acc_raw, id_mask.get_mip()) - else: - return cls.from_object_ids(acc_raw, id_mask) + return cls.from_object_ids(acc_raw, id_mask) else: # assume bounding boxes, exclusively 2d objects df['y'] = df['y0'] diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py index 15c32145..641c887a 100644 --- a/tests/base/test_roiset_pipeline.py +++ b/tests/base/test_roiset_pipeline.py @@ -33,16 +33,17 @@ class BaseTestRoiSetMonoProducts(object): def _get_export_params(self): return { - 'patches_3d': None, - 'annotated_patches_2d': { - 'draw_bounding_box': True, - 'rgb_overlay_channels': [3, None, None], - 'rgb_overlay_weights': [0.2, 1.0, 1.0], - 'pad_to': 512, - }, - 'patches_2d': { - 'draw_bounding_box': False, - 'draw_mask': False, + 'patches': { + '2d_annotated': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + '2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, }, 'annotated_zstacks': None, 'object_classes': True, @@ -102,7 +103,7 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): {f'{k}_': v['model'] for k, v in self._get_models().items()}, **params.dict() ) - self.assertEqual(trace.pop('annotated_patches_2d').count, n_rois) + self.assertEqual(trace.pop('patches_2d_annotated').count, n_rois) self.assertEqual(trace.pop('patches_2d').count, n_rois) trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False) self.assertTrue('ob_id' in trace.keys()) -- GitLab