diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index 39da79b5a46e72fc500015a2abbf98fdb7c5a0e6..1ac0321cd999cc692c6b9ea4dd94e58c459fb389 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -106,7 +106,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             dff['centroid-0'].to_numpy(),
             dff['centroid-1'].to_numpy(),
             dff['zi'].to_numpy(),
-            self.stack.get_one_channel_data(1),
+            self.stack
         )
 
         self.assertEqual(img.shape[0:2], self.stack.shape[0:2])
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 6206c7371c4aa0505ca198a194f93335a5feacfa..3ef3cd97aa422f03bd9673f1a83aaff0a3de8fa8 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -154,20 +154,18 @@ def project_stack_from_focal_points(
         zz: np.ndarray,
         stack: GenericImageDataAccessor,
         degree: int = 2,
-):
-    # TODO: add weights
+) -> np.ndarray:
     """
-    Given a set of 3D points, project a multichannel z-stack
-    :param xx:
-    :param yy:
-    :param zz:
-    :param stack:
-    :param degree:
-    :return:
+    Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points
+    :param xx: vector of point x-coordinates
+    :param yy: vector of point y-coordinates
+    :param zz: vector of point z-coordinates
+    :param stack: z-stack to project
+    :param degree: order of polynomial to fit
+    :return: multichannel 2d projected image array
     """
     assert xx.shape == yy.shape
     assert xx.shape == zz.shape
-    assert stack.chroma == 1
 
     poly = PolynomialFeatures(degree=degree)
     X = np.stack([xx, yy]).T
@@ -175,20 +173,23 @@ def project_stack_from_focal_points(
     model = LinearRegression(fit_intercept=False)
     model.fit(features, zz)
 
-    output_shape = stack.hw
-    xy_indices = np.indices(output_shape).reshape(2, -1).T
+    xy_indices = np.indices(stack.hw).reshape(2, -1).T
     xy_features = np.dot(
         poly.fit_transform(xy_indices, zz),
         model.coef_
     )
     zi_image = xy_features.reshape(
-        output_shape
+        stack.hw
     ).round().clip(
         0, stack.nz
     ).astype('uint16')
 
     return np.take_along_axis(
-        stack.data[:, :, 0, :],
-        np.expand_dims(zi_image, 2),
-        axis=2
-    ).reshape(output_shape)
+        stack.data,
+        np.repeat(
+            np.expand_dims(zi_image, (2, 3)),
+            stack.chroma,
+            axis=2
+        ),
+        axis=3
+    )