From ee6740a1d59f1525ce63bdf047a0af5fd8a3e3e8 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 7 Dec 2024 08:58:31 +0100
Subject: [PATCH] Getting a patch mask can optionally update patch focus

---
 model_server/base/roiset.py | 22 +++++++++++++++-------
 tests/base/test_roiset.py   | 26 ++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 7 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 510b9fe3..0a8775e6 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -40,6 +40,10 @@ class PatchParams(BaseModel):
     pad_to: Union[int, None] = 256
     expanded: bool = False
     force_tif: bool = False
+    update_focus_zi: bool = Field(
+        False,
+        description='If generating 2d patches with a 3d focus metric, update the RoiSet patch focus zi accordingly'
+    )
 
 class AnnotatedPatchParams(PatchParams):
     bounding_box_channel: int = 1
@@ -888,6 +892,7 @@ class RoiSet(object):
             rgb_overlay_weights: list = [1.0, 1.0, 1.0],
             white_channel: int = None,
             expanded=False,
+            update_focus_zi=False,
             **kwargs
     ) -> pd.Series:
 
@@ -937,7 +942,7 @@ class RoiSet(object):
             else:
                 stack = raw.data
 
-        def _make_patch(roi):
+        def _make_patch(roi):  # extract, focus, and annotate a patch
             if expanded:
                 patch3d = stack[roi.expanded_slice]
                 subpatch = patch3d[roi.relative_slice]
@@ -947,11 +952,11 @@ class RoiSet(object):
 
             ph, pw, pc, pz = patch3d.shape
 
-            # make a 3d patch
+            # make a 3d patch, focus stays where it is
             if make_3d:
                 patch = patch3d.copy()
+                zif = roi.zi
 
-            # TODO: somehow persist the result of patch focusing for use by patch mask exporter
             # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
             elif focus_metric is not None:
                 foc = focus_metrics()[focus_metric]
@@ -965,8 +970,8 @@ class RoiSet(object):
 
             # make a 2d patch from middle of z-stack
             else:
-                zim = floor(pz / 2)
-                patch = patch3d[:, :, :, [zim]]
+                zif = floor(pz / 2)
+                patch = patch3d[:, :, :, [zif]]
 
             assert len(patch.shape) == 4
 
@@ -1012,9 +1017,12 @@ class RoiSet(object):
 
             if pad_to and expanded:
                 patch = pad(patch, pad_to)
-            return patch
+            return {'patch': patch, 'zif': zif}
 
-        return self._df.apply(lambda r: _make_patch(r), axis=1)
+        df_processed_patches = self._df.apply(lambda r: _make_patch(r), axis=1, result_type='expand')
+        if update_focus_zi:
+            self._df['zi'] = df_processed_patches['zif']
+        return df_processed_patches['patch']
 
     @property
     def classification_columns(self):
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index e0b9168d..376b8b44 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -296,6 +296,7 @@ class TestRoiSet3dProducts(unittest.TestCase):
                     'draw_mask': False,
                     'pad_to': None,
                     'rgb_overlay_channels': None,
+                    'update_focus_zi': False,
                 },
                 '3d_masks': {
                     'make_3d': True,
@@ -324,6 +325,31 @@ class TestRoiSet3dProducts(unittest.TestCase):
         self.assertTrue(any([_get_nz_from_file(f) > 1 for f in res['patches_3d_masks']]))
         self.assertTrue(all([_get_nz_from_file(f) == 1 for f in res['patches_2d_masks']]))
 
+    def test_run_update_focus_during_patch_export(self):
+        p = RoiSetExportParams(**{
+            'patches': {
+                '2d': {
+                    'make_3d': False,
+                    'focus_metric': 'max_sobel',
+                    'white_channel': -1,
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                    'pad_to': None,
+                    'rgb_overlay_channels': None,
+                    'update_focus_zi': True,
+                },
+            },
+        })
+        roiset = self.test_create_roiset_from_3d_obj_ids()
+        starting_zi = roiset.get_df()['zi']
+
+        res = roiset.run_exports(
+            self.where,
+            prefix='test',
+            params=p
+        )
+        updated_zi = roiset.get_df()['zi']
+        self.assertTrue((starting_zi != updated_zi).any())
 
     def test_run_export_mono_3d_labels_overlay(self):
         # export via .run_exports() method
-- 
GitLab