From 222109507ea5e1512fb381200cb1a52af901a2cc Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 9 Oct 2023 16:19:52 +0200
Subject: [PATCH] Corrected problem with autofocusing patches

---
 extensions/chaeo/products.py | 36 +++++++++++++++++++++++++++---------
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/extensions/chaeo/products.py b/extensions/chaeo/products.py
index b056a4ad..98a087e8 100644
--- a/extensions/chaeo/products.py
+++ b/extensions/chaeo/products.py
@@ -1,4 +1,4 @@
-from math import sqrt
+from math import floor, sqrt
 from pathlib import Path
 
 import numpy as np
@@ -41,7 +41,6 @@ def export_patches_from_zstack(
         pad_to: int = 256,
         make_3d: bool = False,
         prefix='patch',
-        projector=lambda x: np.max(x, axis=3, keepdims=True),
         **kwargs
 ):
     assert stack.chroma == 1, 'Expecting monochromatic image data'
@@ -53,10 +52,34 @@ def export_patches_from_zstack(
         sl = mi['slice']
         rbb = mi['relative_bounding_box']
 
+        x0 = rbb['x0']
+        y0 = rbb['y0']
+        x1 = rbb['x1']
+        y1 = rbb['y1']
+
+        patch3d = stack.data[sl]
+        ph, pw, pc, pz = patch3d.shape
+
+        # make a 3d patch
         if make_3d:
-            patch = stack.data[sl]
+            patch = patch3d
+
+        # make a 2d patch, find optimal z-position determined by focus_metric function
+        elif foc := kwargs.get('focus_metric'):
+            sp_sl = np.s_[y0: y1, x0: x1, :, :]
+            subpatch = patch3d[sp_sl]
+
+            patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
+
+            for ci in range(0, pc):
+                me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
+                zif = np.argmax(me)
+                patch[:, :, ci, 0] = patch3d[:, :, ci, zif]
+
+        # make a 2d patch from middle of z-stack
         else:
-            patch = projector(stack.data[sl])
+            zim = floor(pz / 2)
+            patch = patch3d[:, :, :, [zim]]
 
         assert len(patch.shape) == 4
         assert patch.shape[2] == stack.chroma
@@ -65,11 +88,6 @@ def export_patches_from_zstack(
             patch = rescale(patch, rescale_clip)
 
         if kwargs.get('draw_bounding_box') is True:
-            x0 = rbb['x0']
-            y0 = rbb['y0']
-            x1 = rbb['x1']
-            y1 = rbb['y1']
-
             for zi in range(0, patch.shape[3]):
                 patch[:, :, 0, zi] = draw_box_on_patch(
                     patch[:, :, 0, zi],
-- 
GitLab