From 9bce82686ca8a8833c4126cafccff7182b9cd69b Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 5 Oct 2023 15:38:40 +0200 Subject: [PATCH] Generalized to flattening multichannel images --- extensions/chaeo/tests/test_zstack.py | 2 +- extensions/chaeo/zmask.py | 35 ++++++++++++++------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py index 39da79b5..1ac0321c 100644 --- a/extensions/chaeo/tests/test_zstack.py +++ b/extensions/chaeo/tests/test_zstack.py @@ -106,7 +106,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase): dff['centroid-0'].to_numpy(), dff['centroid-1'].to_numpy(), dff['zi'].to_numpy(), - self.stack.get_one_channel_data(1), + self.stack ) self.assertEqual(img.shape[0:2], self.stack.shape[0:2]) diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 6206c737..3ef3cd97 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -154,20 +154,18 @@ def project_stack_from_focal_points( zz: np.ndarray, stack: GenericImageDataAccessor, degree: int = 2, -): - # TODO: add weights +) -> np.ndarray: """ - Given a set of 3D points, project a multichannel z-stack - :param xx: - :param yy: - :param zz: - :param stack: - :param degree: - :return: + Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points + :param xx: vector of point x-coordinates + :param yy: vector of point y-coordinates + :param zz: vector of point z-coordinates + :param stack: z-stack to project + :param degree: order of polynomial to fit + :return: multichannel 2d projected image array """ assert xx.shape == yy.shape assert xx.shape == zz.shape - assert stack.chroma == 1 poly = PolynomialFeatures(degree=degree) X = np.stack([xx, yy]).T @@ -175,20 +173,23 @@ def project_stack_from_focal_points( model = LinearRegression(fit_intercept=False) model.fit(features, zz) - output_shape = stack.hw - xy_indices = np.indices(output_shape).reshape(2, -1).T + xy_indices = np.indices(stack.hw).reshape(2, -1).T xy_features = np.dot( poly.fit_transform(xy_indices, zz), model.coef_ ) zi_image = xy_features.reshape( - output_shape + stack.hw ).round().clip( 0, stack.nz ).astype('uint16') return np.take_along_axis( - stack.data[:, :, 0, :], - np.expand_dims(zi_image, 2), - axis=2 - ).reshape(output_shape) + stack.data, + np.repeat( + np.expand_dims(zi_image, (2, 3)), + stack.chroma, + axis=2 + ), + axis=3 + ) -- GitLab