From 8f3d5bc401303f15b4747652c2fefa9160543112 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 29 Nov 2024 15:09:29 +0100
Subject: [PATCH] Deserialization does not solve deprojection a second time; a
 deprojected RoiSet is serialized as 3D.  Tight patch mask format depends on
 dimensions of ROIs themselves, so 2d PNGs are possible in a 3D set if no ROI
 is thicker than 1

---
 model_server/base/roiset.py        | 11 ++++++-----
 tests/base/test_roiset_pipeline.py | 23 ++++++++++++-----------
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 509650a2..ccd79c2d 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -320,6 +320,7 @@ def make_df_from_object_ids(
             return rel_argzmax + r.z0
 
         df['zi'] = df.apply(_get_zi_from_label, axis=1)
+        df['nz'] = df['z1'] - df['z0']
 
         def _make_binary_mask(r):
             r = r.convert_dtypes()
@@ -809,7 +810,10 @@ class RoiSet(object):
 
     def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame:
         patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy()
-        ext = 'tif' if is_df_3d(patches_df) else 'png'
+        if 'nz' in patches_df.columns and any(patches_df['nz'] > 1):
+            ext = 'tif'
+        else:
+            ext = 'png'
 
         def _export_patch_mask(roi):
             patch = InMemoryDataAccessor.from_mono(roi.patch_mask)
@@ -1214,10 +1218,7 @@ class RoiSet(object):
 
             df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
             id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
-            if not is_3d and id_mask.nz > 1:
-                return cls.from_object_ids(acc_raw, id_mask.get_mip())
-            else:
-                return cls.from_object_ids(acc_raw, id_mask)
+            return cls.from_object_ids(acc_raw, id_mask)
 
         else:  # assume bounding boxes, exclusively 2d objects
             df['y'] = df['y0']
diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py
index 15c32145..641c887a 100644
--- a/tests/base/test_roiset_pipeline.py
+++ b/tests/base/test_roiset_pipeline.py
@@ -33,16 +33,17 @@ class BaseTestRoiSetMonoProducts(object):
 
     def _get_export_params(self):
         return {
-            'patches_3d': None,
-            'annotated_patches_2d': {
-                'draw_bounding_box': True,
-                'rgb_overlay_channels': [3, None, None],
-                'rgb_overlay_weights': [0.2, 1.0, 1.0],
-                'pad_to': 512,
-            },
-            'patches_2d': {
-                'draw_bounding_box': False,
-                'draw_mask': False,
+            'patches': {
+                '2d_annotated': {
+                    'draw_bounding_box': True,
+                    'rgb_overlay_channels': [3, None, None],
+                    'rgb_overlay_weights': [0.2, 1.0, 1.0],
+                    'pad_to': 512,
+                },
+                '2d': {
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                },
             },
             'annotated_zstacks': None,
             'object_classes': True,
@@ -102,7 +103,7 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
             {f'{k}_': v['model'] for k, v in self._get_models().items()},
             **params.dict()
         )
-        self.assertEqual(trace.pop('annotated_patches_2d').count, n_rois)
+        self.assertEqual(trace.pop('patches_2d_annotated').count, n_rois)
         self.assertEqual(trace.pop('patches_2d').count, n_rois)
         trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)
         self.assertTrue('ob_id' in trace.keys())
-- 
GitLab