From 6e8eefd62c1dd02c91ba74a01927e81d442d7e28 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 11 Apr 2024 11:43:25 +0200
Subject: [PATCH] RoiSet can passes multiple input channels to object
 classifier

---
 model_server/base/roiset.py | 10 +++++-----
 tests/test_accessors.py     | 10 +++++++++-
 tests/test_roiset.py        |  9 ++++++++-
 3 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index e8b3d04e..f0d977e8 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -258,9 +258,9 @@ class RoiSet(object):
             projected = self.acc_raw.data.max(axis=-1)
         return projected
 
-    def get_patches_acc(self, channel=None, **kwargs) -> PatchStack:  # padded, un-annotated 2d patches
-        if channel:
-            patches_df = self.get_patches(white_channel=channel, **kwargs)
+    def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack:  # padded, un-annotated 2d patches
+        if channels and len(channels) == 1:
+            patches_df = self.get_patches(white_channel=channels[0], **kwargs)
         else:
             patches_df = self.get_patches(**kwargs)
         return PatchStack(list(patches_df.patch))
@@ -308,9 +308,9 @@ class RoiSet(object):
         return zi_st
 
 
-    def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable()] = None):
+    def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable] = None):
 
-        raw_acc = self.get_patches_acc(expanded=False, pad_to=None)  # all channels
+        raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
         if derived_channel_functions:
             mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)]
             for k, fcn in derived_channel_functions.keys():
diff --git a/tests/test_accessors.py b/tests/test_accessors.py
index 091b6684..a981a180 100644
--- a/tests/test_accessors.py
+++ b/tests/test_accessors.py
@@ -221,7 +221,7 @@ class TestPatchStackAccessor(unittest.TestCase):
         h = 512
         n = 4
         nz = 15
-        nc = 2
+        nc = 3
         acc = PatchStack(_random_int(n, h, w, nc, nz))
         self.assertEqual(acc.count, n)
         self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w))
@@ -235,6 +235,14 @@ class TestPatchStackAccessor(unittest.TestCase):
             self.assertEqual(mono.shape_dict[a], acc.shape_dict[a])
         self.assertEqual(mono.shape_dict['C'], 1)
 
+    def test_get_multiple_channels(self):
+        acc = self.test_pczyx()
+        channels = [0, 1]
+        mcacc = acc.get_channels(channels=channels)
+        for a in 'PXYZ':
+            self.assertEqual(mcacc.shape_dict[a], acc.shape_dict[a])
+        self.assertEqual(mcacc.shape_dict['C'], len(channels))
+
     def test_get_one_channel_mip(self):
         acc = self.test_pczyx()
         mono_mip = acc.get_one_channel_data(channel=1, mip=True)
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index 9d32668c..6d497259 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -195,7 +195,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_classify_by(self):
         roiset = self._make_roi_set()
-        roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel())
+        roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
+        self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
+        self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
+        return roiset
+
+    def test_classify_by_multiple_channels(self):
+        roiset = self._make_roi_set()
+        roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
         self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
         self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
         return roiset
-- 
GitLab