From dde6c75e13bd28d306a698ea4f8edbf3af1e2bec Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 29 Nov 2024 12:03:58 +0100
Subject: [PATCH] Patch exports are passed to a single parameter now, but still
 working out bugs

---
 model_server/base/roiset.py | 68 +++++++++++++++------------
 tests/base/test_roiset.py   | 93 +++++++++++++++++++------------------
 2 files changed, 87 insertions(+), 74 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 70776fd0..46a7f114 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -26,6 +26,7 @@ from .process import mask_largest_object
 
 
 class PatchParams(BaseModel):
+    make_3d: bool = False
     white_channel: Union[int, None] = None
     channels: Union[List[int], None] = None
     draw_bounding_box: bool = False
@@ -72,9 +73,7 @@ class RoiSetMetaParams(BaseModel):
 
 
 class RoiSetExportParams(BaseModel):
-    patches_3d: Union[PatchParams, None] = None
-    annotated_patches_2d: Union[AnnotatedPatchParams, None] = None
-    patches_2d: Union[PatchParams, None] = None
+    patches: Union[Dict[str, PatchParams], None] = None
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
     object_classes: bool = False
     labels_overlay: Union[RoiSetLabelsOverlayParams, None] = None
@@ -992,6 +991,7 @@ class RoiSet(object):
                 patch = pad(patch, pad_to)
             return patch
 
+        # TODO: just return needed rows, without DataFrame copy
         dfe = self._df.copy()
         dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1)
         return dfe
@@ -1024,47 +1024,55 @@ class RoiSet(object):
         if not self.count:
             return
 
-        for k in params.dict().keys():
-            pr = prefix
-            subdir = Path(k)
-            if 'patches' in k and params.write_patches_to_subdirectory:
-                subdir = subdir / pr
-            kp = params.dict()[k]
+        for k, kp in params.dict().items():
             if kp is None:
                 continue
-            if k == 'patches_3d':
-                df_exp = self.export_patches(
-                    where / subdir, prefix=pr, make_3d=True, **kp
-                )
-                record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
-            if k == 'annotated_patches_2d':
-                df_exp = self.export_patches(
-                    where / subdir, prefix=pr, make_3d=False, **kp,
-                )
-                record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
-            if k == 'patches_2d':
-                df_exp = self.export_patches(
-                    where / subdir, prefix=pr, make_3d=False, **kp
-                )
-                self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(subdir / x)))
-                self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
-                record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
+            if k == 'patches':
+                for pk, pp in kp.items():
+                    product_name = f'patches_{pk}'
+                    subdir = Path(product_name)
+                    if params.write_patches_to_subdirectory:
+                        subdir = subdir / prefix
+
+                    df_exports = self.export_patches(where / subdir, prefix=prefix, **pp)
+
+                    df_patch_paths = df_exports.patch_path.apply(lambda x: str(subdir / x))
+                    self._df = self._df.join(df_patch_paths)  # TODO: rename his column
+                    self._df[f'{product_name}_id'] = self._df.apply(lambda _: uuid4(), axis=1)
+                    record[product_name] = list(df_patch_paths)
+            # if k == 'patches_3d':
+            #     df_exp = self.export_patches(
+            #         where / k, prefix=prefix, make_3d=True, **kp
+            #     )
+            #     record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
+            # if k == 'annotated_patches_2d':
+            #     df_exp = self.export_patches(
+            #         where / k, prefix=prefix, make_3d=False, **kp,
+            #     )
+            #     record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
+            # if k == 'patches_2d':
+            #     df_exp = self.export_patches(
+            #         where / k, prefix=prefix, make_3d=False, **kp
+            #     )
+            #     self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(k / x)))
+            #     self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
+            #     record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
             if k == 'annotated_zstacks':
-                record[k] = str(Path(k) / self.export_annotated_zstack(where / subdir, prefix=pr, **kp))
+                record[k] = str(Path(k) / self.export_annotated_zstack(where / k, prefix=prefix, **kp))
             if k == 'object_classes':
                 for n in self.classification_columns:
-                    fp = where / subdir / n / (pr + '.tif')
+                    fp = where / k / n / (prefix + '.tif')
                     write_accessor_data_to_file(fp, self.get_object_class_map(n))
                     record[f'{k}_{n}'] = str(fp)
             if k == 'derived_channels':
                 record[k] = []
                 for di, dacc in enumerate(self.accs_derived):
-                    fp = where / subdir / f'dc{di:01d}.tif'
+                    fp = where / k / f'dc{di:01d}.tif'
                     fp.parent.mkdir(exist_ok=True, parents=True)
                     dacc.export_pyxcz(fp)
                     record[k].append(str(fp))
             if k == 'labels_overlay':
-                fn = self.export_object_identities_overlay_map(where / subdir, prefix=prefix, **kp)
+                fn = self.export_object_identities_overlay_map(where / k, prefix=prefix, **kp)
                 record[k] = str(Path(k) / fn)
 
         # export dataframe and patch masks
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index f0a099f2..5640c7f6 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -256,6 +256,7 @@ class TestRoiSet3dProducts(unittest.TestCase):
         self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
         self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel'])
         self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path'])
+        self.where.mkdir(parents=True, exist_ok=True)
 
     def test_create_roiset_from_3d_obj_ids(self):
         roiset = RoiSet.from_binary_mask(
@@ -276,12 +277,15 @@ class TestRoiSet3dProducts(unittest.TestCase):
 
     def test_run_export_mono_3d_patch(self):
         p = RoiSetExportParams(**{
-            'patches_3d': {
-                'white_channel': -1,
-                'draw_bounding_box': False,
-                'draw_mask': False,
-                'pad_to': None,
-                'rgb_overlay_channels': None,
+            'patches': {
+                '3d': {
+                    'make_3d': True,
+                    'white_channel': -1,
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                    'pad_to': None,
+                    'rgb_overlay_channels': None,
+                },
             },
         })
 
@@ -511,21 +515,19 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def test_run_exports(self):
         p = RoiSetExportParams(**{
-            'patches_3d': {},
-            'annotated_patches_2d': {
-                'white_channel': 3,
-                'draw_bounding_box': True,
-                'rgb_overlay_channels': [3, None, None],
-                'rgb_overlay_weights': [0.2, 1.0, 1.0],
-                'pad_to': 512,
-            },
-            'patches_2d': {
-                'white_channel': 3,
-                'draw_bounding_box': False,
-                'draw_mask': False,
-            },
-            'patch_masks': {
-                'pad_to': 256,
+            'patches': {
+                'annotated_patches_2d': {
+                    'white_channel': 3,
+                    'draw_bounding_box': True,
+                    'rgb_overlay_channels': [3, None, None],
+                    'rgb_overlay_weights': [0.2, 1.0, 1.0],
+                    'pad_to': 512,
+                },
+                'patches_2d': {
+                    'white_channel': 3,
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                },
             },
             'annotated_zstacks': {},
             'object_classes': True,
@@ -561,18 +563,19 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def test_get_interm_prods(self):
         p = RoiSetExportParams(**{
-            'patches_3d': None,
-            'annotated_patches_2d': {
-                'white_channel': 3,
-                'draw_bounding_box': True,
-                'rgb_overlay_channels': [3, None, None],
-                'rgb_overlay_weights': [0.2, 1.0, 1.0],
-                'pad_to': 512,
-            },
-            'patches_2d': {
-                'white_channel': 3,
-                'draw_bounding_box': False,
-                'draw_mask': False,
+            'patches': {
+                'patches_2d': {
+                    'white_channel': 3,
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                },
+                'annotated_patches_2d': {
+                    'white_channel': 3,
+                    'draw_bounding_box': True,
+                    'rgb_overlay_channels': [3, None, None],
+                    'rgb_overlay_weights': [0.2, 1.0, 1.0],
+                    'pad_to': 512,
+                },
             },
             'annotated_zstacks': {},
             'object_classes': True,
@@ -610,8 +613,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
                 'pad_to': 256,
             },
         })
-        self.assertTrue(hasattr(p.patches_2d, 'pad_to'))
-        self.assertTrue(hasattr(p.patches_2d, 'expanded'))
+        self.assertTrue(hasattr(p.patches['2d'], 'pad_to'))
+        self.assertTrue(hasattr(p.patches['2d'], 'expanded'))
 
         where = output_path / 'run_exports_expanded_2d_patch'
         res = self.roiset.run_exports(
@@ -629,17 +632,19 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def test_run_export_mono_2d_patch(self):
         p = RoiSetExportParams(**{
-            'patches_2d': {
-                'white_channel': -1,
-                'draw_bounding_box': False,
-                'draw_mask': False,
-                'expanded': True,
-                'pad_to': 256,
-                'rgb_overlay_channels': None,
+            'patches': {
+                '2d': {
+                    'white_channel': -1,
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                    'expanded': True,
+                    'pad_to': 256,
+                    'rgb_overlay_channels': None,
+                },
             },
         })
-        self.assertTrue(hasattr(p.patches_2d, 'pad_to'))
-        self.assertTrue(hasattr(p.patches_2d, 'expanded'))
+        self.assertTrue(hasattr(p.patches['2d'], 'pad_to'))
+        self.assertTrue(hasattr(p.patches['2d'], 'expanded'))
 
         where = output_path / 'run_exports_mono_2d_patch'
         res = self.roiset.run_exports(
-- 
GitLab