diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 2cd56349041a313ba1d5d244790d24570c6b0d53..9a6de5e86bf9d13a1aa2d7e62af7ba974acd2920 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): assert segmentation_acc.is_mask() - assert input_acc.chroma == 1 + if not input_acc.chroma == 1: + raise InvalidInputImageError('Object classifier expects only monochrome patches') + if not input_acc.nz == 1: + raise InvalidInputImageError('Object classifier expects only 2d patches') tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') @@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): assert len(obmaps) == 1, 'ilastik generated more than one object map' - # for some reason ilastik scrambles these axes to Z(1)YX(1) - assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1) - yxz = np.moveaxis( - obmaps[0][:, 0, :, :, 0], - [1, 2, 0], - [0, 1, 2] + # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C + assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] ) - assert yxz.shape[0:2] == input_acc.hw - assert yxz.shape[2] == input_acc.nz - return PatchStack(data=yxz), {'success': True} + return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index c21e338d758a253392f19823bffd72504e7d0082..5f765ba709eb30364eda1e2c9fa60db57f5c0f69 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase): def test_classify_patches(self): raw_patches = self.roiset.get_raw_patches() patch_masks = self.roiset.get_patch_masks() - res = self.object_classifier.infer(raw_patches, patch_masks) - self.assertEqual(0, 1) + res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) + self.assertEqual(res_patches.count, self.roiset.count) + for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch + unique = np.unique(res_patches.iat(pi)) + self.assertEqual(len(unique), 2) + self.assertEqual(unique[0], 0)