From 1092b9c55d898912c82ecb14ed64bee6b239b108 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 5 Oct 2023 15:25:29 +0200 Subject: [PATCH] Tested operation to flatten image based on focal points from body mask --- .../chaeo/examples/batch_run_patches.py | 2 +- extensions/chaeo/tests/test_zstack.py | 24 +++++++++ extensions/chaeo/workflows.py | 10 ++++ extensions/chaeo/zmask.py | 50 ++++++++++++++++++- 4 files changed, 84 insertions(+), 2 deletions(-) diff --git a/extensions/chaeo/examples/batch_run_patches.py b/extensions/chaeo/examples/batch_run_patches.py index cebc0b40..fb96be1a 100644 --- a/extensions/chaeo/examples/batch_run_patches.py +++ b/extensions/chaeo/examples/batch_run_patches.py @@ -51,7 +51,7 @@ if __name__ == '__main__': } result = export_patches_from_multichannel_zstack(**export_kwargs) - + break # parse and record results df = result['dataframe'] df['filename'] = ff.name diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py index f2394889..39da79b5 100644 --- a/extensions/chaeo/tests/test_zstack.py +++ b/extensions/chaeo/tests/test_zstack.py @@ -91,3 +91,27 @@ class TestZStackDerivedDataProducts(unittest.TestCase): make_3d=True) self.assertGreaterEqual(len(files), 1) + def test_flatten_image(self): + zmask, meta, df, interm = build_zmask_from_object_mask( + self.obmap.get_one_channel_data(0), + self.stack.get_one_channel_data(1), + mask_type='boxes', + ) + + from extensions.chaeo.zmask import project_stack_from_focal_points + + dff = df[df['keeper'] == True] + + img = project_stack_from_focal_points( + dff['centroid-0'].to_numpy(), + dff['centroid-1'].to_numpy(), + dff['zi'].to_numpy(), + self.stack.get_one_channel_data(1), + ) + + self.assertEqual(img.shape[0:2], self.stack.shape[0:2]) + + write_accessor_data_to_file( + output_path / 'flattened.tif', + InMemoryDataAccessor(img) + ) \ No newline at end of file diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index a47986f8..14e3fd2d 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -83,6 +83,16 @@ def export_patches_from_multichannel_zstack( ) ti.click('export_annotated_zstack') + # from extensions.chaeo.zmask import build_image_flattening_zmask_from_points + # + # dff = df[df['keeper'] == True] + # build_image_flattening_zmask_from_points( + # dff['centroid-0'], + # dff['centroid-1'], + # dff['zi'], + # stack.get_one_channel_data(patches_channel).data, + # ) + return { 'pixel_model_id': px_model.model_id, 'input_filepath': input_zstack_path, diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 66e01f95..6206c737 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -2,6 +2,8 @@ import numpy as np import pandas as pd from skimage.measure import find_contours, label, regionprops_table +from sklearn.preprocessing import PolynomialFeatures +from sklearn.linear_model import LinearRegression from model_server.accessors import GenericImageDataAccessor @@ -143,4 +145,50 @@ def build_zmask_from_object_mask( 'argmax': argmax, } - return zi_st, meta, df, interm \ No newline at end of file + return zi_st, meta, df, interm + + +def project_stack_from_focal_points( + xx: np.ndarray, + yy: np.ndarray, + zz: np.ndarray, + stack: GenericImageDataAccessor, + degree: int = 2, +): + # TODO: add weights + """ + Given a set of 3D points, project a multichannel z-stack + :param xx: + :param yy: + :param zz: + :param stack: + :param degree: + :return: + """ + assert xx.shape == yy.shape + assert xx.shape == zz.shape + assert stack.chroma == 1 + + poly = PolynomialFeatures(degree=degree) + X = np.stack([xx, yy]).T + features = poly.fit_transform(X, zz) + model = LinearRegression(fit_intercept=False) + model.fit(features, zz) + + output_shape = stack.hw + xy_indices = np.indices(output_shape).reshape(2, -1).T + xy_features = np.dot( + poly.fit_transform(xy_indices, zz), + model.coef_ + ) + zi_image = xy_features.reshape( + output_shape + ).round().clip( + 0, stack.nz + ).astype('uint16') + + return np.take_along_axis( + stack.data[:, :, 0, :], + np.expand_dims(zi_image, 2), + axis=2 + ).reshape(output_shape) -- GitLab