diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 2cd56349041a313ba1d5d244790d24570c6b0d53..9a6de5e86bf9d13a1aa2d7e62af7ba974acd2920 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
 
     def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
         assert segmentation_acc.is_mask()
-        assert input_acc.chroma == 1
+        if not input_acc.chroma == 1:
+            raise InvalidInputImageError('Object classifier expects only monochrome patches')
+        if not input_acc.nz == 1:
+            raise InvalidInputImageError('Object classifier expects only 2d patches')
 
         tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
         tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
@@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
 
         assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
-        # for some reason ilastik scrambles these axes to Z(1)YX(1)
-        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
-        yxz = np.moveaxis(
-            obmaps[0][:, 0, :, :, 0],
-            [1, 2, 0],
-            [0, 1, 2]
+        # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
+        assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
+        pyxcz = np.moveaxis(
+            obmaps[0],
+            [0, 1, 2, 3, 4],
+            [0, 4, 1, 2, 3]
         )
 
-        assert yxz.shape[0:2] == input_acc.hw
-        assert yxz.shape[2] == input_acc.nz
-        return PatchStack(data=yxz), {'success': True}
+        return PatchStack(data=pyxcz), {'success': True}
\ No newline at end of file
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index c21e338d758a253392f19823bffd72504e7d0082..5f765ba709eb30364eda1e2c9fa60db57f5c0f69 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase):
     def test_classify_patches(self):
         raw_patches = self.roiset.get_raw_patches()
         patch_masks = self.roiset.get_patch_masks()
-        res = self.object_classifier.infer(raw_patches, patch_masks)
-        self.assertEqual(0, 1)
+        res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
+        self.assertEqual(res_patches.count, self.roiset.count)
+        for pi in range(0, res_patches.count):  # assert that there is only one nonzero label per patch
+            unique = np.unique(res_patches.iat(pi))
+            self.assertEqual(len(unique), 2)
+            self.assertEqual(unique[0], 0)