From b54bfadd6c24920458b1da91f9936d1a04474c0e Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 19 Feb 2024 13:53:36 +0100
Subject: [PATCH] Test coverage of patch stack object classification

---
 model_server/extensions/ilastik/models.py     | 21 ++++++++++---------
 .../extensions/ilastik/tests/test_ilastik.py  |  8 +++++--
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 2cd56349..9a6de5e8 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
 
     def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
         assert segmentation_acc.is_mask()
-        assert input_acc.chroma == 1
+        if not input_acc.chroma == 1:
+            raise InvalidInputImageError('Object classifier expects only monochrome patches')
+        if not input_acc.nz == 1:
+            raise InvalidInputImageError('Object classifier expects only 2d patches')
 
         tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
         tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
@@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
 
         assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
-        # for some reason ilastik scrambles these axes to Z(1)YX(1)
-        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
-        yxz = np.moveaxis(
-            obmaps[0][:, 0, :, :, 0],
-            [1, 2, 0],
-            [0, 1, 2]
+        # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
+        assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
+        pyxcz = np.moveaxis(
+            obmaps[0],
+            [0, 1, 2, 3, 4],
+            [0, 4, 1, 2, 3]
         )
 
-        assert yxz.shape[0:2] == input_acc.hw
-        assert yxz.shape[2] == input_acc.nz
-        return PatchStack(data=yxz), {'success': True}
+        return PatchStack(data=pyxcz), {'success': True}
\ No newline at end of file
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index c21e338d..5f765ba7 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase):
     def test_classify_patches(self):
         raw_patches = self.roiset.get_raw_patches()
         patch_masks = self.roiset.get_patch_masks()
-        res = self.object_classifier.infer(raw_patches, patch_masks)
-        self.assertEqual(0, 1)
+        res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
+        self.assertEqual(res_patches.count, self.roiset.count)
+        for pi in range(0, res_patches.count):  # assert that there is only one nonzero label per patch
+            unique = np.unique(res_patches.iat(pi))
+            self.assertEqual(len(unique), 2)
+            self.assertEqual(unique[0], 0)
-- 
GitLab