From a3edc3246c0ec1623adb4a282c5f2ad206d7c6e7 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 25 Mar 2024 13:04:19 +0100
Subject: [PATCH] Correct error when trying to extract a single channel from a
 PatchStack

---
 model_server/base/accessors.py |  7 +++++++
 tests/test_accessors.py        | 15 +++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 8cba54bc..12807bf2 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -308,6 +308,13 @@ class PatchStack(InMemoryDataAccessor):
         else:
             tifffile.imwrite(fpath, tzcyx, imagej=True)
 
+    def get_one_channel_data(self, channel: int, mip: bool = False):
+        c = int(channel)
+        if mip:
+            return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :].max(axis=-1, keepdims=True))
+        else:
+            return PatchStack(self.pyxcz[:, :, :, c:(c + 1), :])
+
     @property
     def shape_dict(self):
         return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
diff --git a/tests/test_accessors.py b/tests/test_accessors.py
index d2ca777c..e84788d2 100644
--- a/tests/test_accessors.py
+++ b/tests/test_accessors.py
@@ -200,6 +200,21 @@ class TestPatchStackAccessor(unittest.TestCase):
         self.assertEqual(acc.hw, (h, w))
         return acc
 
+    def test_get_one_channel(self):
+        acc = self.test_pczyx()
+        mono = acc.get_one_channel_data(channel=1)
+        for a in 'PXYZ':
+            self.assertEqual(mono.shape_dict[a], acc.shape_dict[a])
+        self.assertEqual(mono.shape_dict['C'], 1)
+
+    def test_get_one_channel_mip(self):
+        acc = self.test_pczyx()
+        mono_mip = acc.get_one_channel_data(channel=1, mip=True)
+        for a in 'PXY':
+            self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a])
+        for a in 'CZ':
+            self.assertEqual(mono_mip.shape_dict[a], 1)
+
     def test_export_pczyx_patch_hyperstack(self):
         acc = self.test_pczyx()
         fp = output_path / 'patch_hyperstack.tif'
-- 
GitLab