From 6e8eefd62c1dd02c91ba74a01927e81d442d7e28 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 11 Apr 2024 11:43:25 +0200 Subject: [PATCH] RoiSet can passes multiple input channels to object classifier --- model_server/base/roiset.py | 10 +++++----- tests/test_accessors.py | 10 +++++++++- tests/test_roiset.py | 9 ++++++++- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index e8b3d04e..f0d977e8 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -258,9 +258,9 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def get_patches_acc(self, channel=None, **kwargs) -> PatchStack: # padded, un-annotated 2d patches - if channel: - patches_df = self.get_patches(white_channel=channel, **kwargs) + def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack: # padded, un-annotated 2d patches + if channels and len(channels) == 1: + patches_df = self.get_patches(white_channel=channels[0], **kwargs) else: patches_df = self.get_patches(**kwargs) return PatchStack(list(patches_df.patch)) @@ -308,9 +308,9 @@ class RoiSet(object): return zi_st - def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable()] = None): + def classify_by(self, name: str, channels: int, object_classification_model: InstanceSegmentationModel, derived_channel_functions: dict[callable] = None): - raw_acc = self.get_patches_acc(expanded=False, pad_to=None) # all channels + raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels if derived_channel_functions: mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)] for k, fcn in derived_channel_functions.keys(): diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 091b6684..a981a180 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -221,7 +221,7 @@ class TestPatchStackAccessor(unittest.TestCase): h = 512 n = 4 nz = 15 - nc = 2 + nc = 3 acc = PatchStack(_random_int(n, h, w, nc, nz)) self.assertEqual(acc.count, n) self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) @@ -235,6 +235,14 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(mono.shape_dict[a], acc.shape_dict[a]) self.assertEqual(mono.shape_dict['C'], 1) + def test_get_multiple_channels(self): + acc = self.test_pczyx() + channels = [0, 1] + mcacc = acc.get_channels(channels=channels) + for a in 'PXYZ': + self.assertEqual(mcacc.shape_dict[a], acc.shape_dict[a]) + self.assertEqual(mcacc.shape_dict['C'], len(channels)) + def test_get_one_channel_mip(self): acc = self.test_pczyx() mono_mip = acc.get_one_channel_data(channel=1, mip=True) diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 9d32668c..6d497259 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -195,7 +195,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_classify_by(self): roiset = self._make_roi_set() - roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel()) + roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) + self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + return roiset + + def test_classify_by_multiple_channels(self): + roiset = self._make_roi_set() + roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) return roiset -- GitLab