From 1092b9c55d898912c82ecb14ed64bee6b239b108 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 5 Oct 2023 15:25:29 +0200
Subject: [PATCH] Tested operation to flatten image based on focal points from
 body mask

---
 .../chaeo/examples/batch_run_patches.py       |  2 +-
 extensions/chaeo/tests/test_zstack.py         | 24 +++++++++
 extensions/chaeo/workflows.py                 | 10 ++++
 extensions/chaeo/zmask.py                     | 50 ++++++++++++++++++-
 4 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/extensions/chaeo/examples/batch_run_patches.py b/extensions/chaeo/examples/batch_run_patches.py
index cebc0b40..fb96be1a 100644
--- a/extensions/chaeo/examples/batch_run_patches.py
+++ b/extensions/chaeo/examples/batch_run_patches.py
@@ -51,7 +51,7 @@ if __name__ == '__main__':
         }
 
         result = export_patches_from_multichannel_zstack(**export_kwargs)
-
+        break
         # parse and record results
         df = result['dataframe']
         df['filename'] = ff.name
diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index f2394889..39da79b5 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -91,3 +91,27 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             make_3d=True)
         self.assertGreaterEqual(len(files), 1)
 
+    def test_flatten_image(self):
+        zmask, meta, df, interm = build_zmask_from_object_mask(
+            self.obmap.get_one_channel_data(0),
+            self.stack.get_one_channel_data(1),
+            mask_type='boxes',
+        )
+
+        from extensions.chaeo.zmask import project_stack_from_focal_points
+
+        dff = df[df['keeper'] == True]
+
+        img = project_stack_from_focal_points(
+            dff['centroid-0'].to_numpy(),
+            dff['centroid-1'].to_numpy(),
+            dff['zi'].to_numpy(),
+            self.stack.get_one_channel_data(1),
+        )
+
+        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])
+
+        write_accessor_data_to_file(
+            output_path / 'flattened.tif',
+            InMemoryDataAccessor(img)
+        )
\ No newline at end of file
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index a47986f8..14e3fd2d 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -83,6 +83,16 @@ def export_patches_from_multichannel_zstack(
     )
     ti.click('export_annotated_zstack')
 
+    # from extensions.chaeo.zmask import build_image_flattening_zmask_from_points
+    #
+    # dff = df[df['keeper'] == True]
+    # build_image_flattening_zmask_from_points(
+    #     dff['centroid-0'],
+    #     dff['centroid-1'],
+    #     dff['zi'],
+    #     stack.get_one_channel_data(patches_channel).data,
+    # )
+
     return {
         'pixel_model_id': px_model.model_id,
         'input_filepath': input_zstack_path,
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 66e01f95..6206c737 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -2,6 +2,8 @@ import numpy as np
 import pandas as pd
 
 from skimage.measure import find_contours, label, regionprops_table
+from sklearn.preprocessing import PolynomialFeatures
+from sklearn.linear_model import LinearRegression
 
 from model_server.accessors import GenericImageDataAccessor
 
@@ -143,4 +145,50 @@ def build_zmask_from_object_mask(
         'argmax': argmax,
     }
 
-    return zi_st, meta, df, interm
\ No newline at end of file
+    return zi_st, meta, df, interm
+
+
+def project_stack_from_focal_points(
+        xx: np.ndarray,
+        yy: np.ndarray,
+        zz: np.ndarray,
+        stack: GenericImageDataAccessor,
+        degree: int = 2,
+):
+    # TODO: add weights
+    """
+    Given a set of 3D points, project a multichannel z-stack
+    :param xx:
+    :param yy:
+    :param zz:
+    :param stack:
+    :param degree:
+    :return:
+    """
+    assert xx.shape == yy.shape
+    assert xx.shape == zz.shape
+    assert stack.chroma == 1
+
+    poly = PolynomialFeatures(degree=degree)
+    X = np.stack([xx, yy]).T
+    features = poly.fit_transform(X, zz)
+    model = LinearRegression(fit_intercept=False)
+    model.fit(features, zz)
+
+    output_shape = stack.hw
+    xy_indices = np.indices(output_shape).reshape(2, -1).T
+    xy_features = np.dot(
+        poly.fit_transform(xy_indices, zz),
+        model.coef_
+    )
+    zi_image = xy_features.reshape(
+        output_shape
+    ).round().clip(
+        0, stack.nz
+    ).astype('uint16')
+
+    return np.take_along_axis(
+        stack.data[:, :, 0, :],
+        np.expand_dims(zi_image, 2),
+        axis=2
+    ).reshape(output_shape)
-- 
GitLab