From b54bfadd6c24920458b1da91f9936d1a04474c0e Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 19 Feb 2024 13:53:36 +0100 Subject: [PATCH] Test coverage of patch stack object classification --- model_server/extensions/ilastik/models.py | 21 ++++++++++--------- .../extensions/ilastik/tests/test_ilastik.py | 8 +++++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 2cd56349..9a6de5e8 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): assert segmentation_acc.is_mask() - assert input_acc.chroma == 1 + if not input_acc.chroma == 1: + raise InvalidInputImageError('Object classifier expects only monochrome patches') + if not input_acc.nz == 1: + raise InvalidInputImageError('Object classifier expects only 2d patches') tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') @@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): assert len(obmaps) == 1, 'ilastik generated more than one object map' - # for some reason ilastik scrambles these axes to Z(1)YX(1) - assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1) - yxz = np.moveaxis( - obmaps[0][:, 0, :, :, 0], - [1, 2, 0], - [0, 1, 2] + # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C + assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] ) - assert yxz.shape[0:2] == input_acc.hw - assert yxz.shape[2] == input_acc.nz - return PatchStack(data=yxz), {'success': True} + return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index c21e338d..5f765ba7 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase): def test_classify_patches(self): raw_patches = self.roiset.get_raw_patches() patch_masks = self.roiset.get_patch_masks() - res = self.object_classifier.infer(raw_patches, patch_masks) - self.assertEqual(0, 1) + res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) + self.assertEqual(res_patches.count, self.roiset.count) + for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch + unique = np.unique(res_patches.iat(pi)) + self.assertEqual(len(unique), 2) + self.assertEqual(unique[0], 0) -- GitLab