diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 48a18758ea486523cb6c38c31a7e7de695d10800..d17e69545c0aa746c3ebd17ddc89580ace56eadd 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -1,3 +1,4 @@
+import json
 import os
 from pathlib import Path
 
@@ -73,12 +74,31 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
     model_id = 'ilastik_pixel_classification'
     operations = ['segment', ]
 
+    @property
+    def model_shape_dict(self):
+        raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
+        ax = raw_info['axistags'][()]
+        ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']]
+        shape = raw_info['shape'][()]
+        return dict(zip(ax_keys, shape))
+
+    @property
+    def model_chroma(self):
+        return self.model_shape_dict['C']
+
+    @property
+    def model_3d(self):
+        return self.model_shape_dict['Z'] > 1
+
     @staticmethod
     def get_workflow():
         from ilastik.workflows import PixelClassificationWorkflow
         return PixelClassificationWorkflow
 
     def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
+            raise IlastikInputShapeError()
+
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
         dsi = [
             {
@@ -221,4 +241,8 @@ class Error(Exception):
     pass
 
 class IlastikInputEmbedding(Error):
+    pass
+
+class IlastikInputShapeError(Error):
+    """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape"""
     pass
\ No newline at end of file
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index e81a057ff6934374ba583f7650c6b876c59829f2..7adde1a4cbfcfaa953ff927ca9c0d7f1b4221a9f 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
 from model_server.base.workflows import classify_pixels
 from tests.test_api import TestServerBaseClass
 
+def _random_int(*args):
+    return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
+
 class TestIlastikPixelClassification(unittest.TestCase):
     def setUp(self) -> None:
         self.cf = CziImageFileAccessor(czifile['path'])
@@ -83,6 +86,40 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.mono_image = mono_image
         self.mask = mask
 
+    def test_pixel_classifier_enforces_input_shape(self):
+        model = ilm.IlastikPixelClassifierModel(
+            {'project_file': ilastik_classifiers['px']}
+        )
+        self.assertEqual(model.model_chroma, 1)
+        self.assertEqual(model.model_3d, False)
+
+        # correct data
+        self.assertIsInstance(
+            model.label_pixel_class(
+                InMemoryDataAccessor(
+                    _random_int(512, 256, 1, 1)
+                )
+            ),
+            InMemoryDataAccessor
+        )
+
+        # raise except with input of multiple channels
+        with self.assertRaises(ilm.IlastikInputShapeError):
+            mask = model.label_pixel_class(
+                InMemoryDataAccessor(
+                    _random_int(512, 256, 3, 1)
+                )
+            )
+
+        # raise except with input of multiple channels
+        with self.assertRaises(ilm.IlastikInputShapeError):
+            mask = model.label_pixel_class(
+                InMemoryDataAccessor(
+                    _random_int(512, 256, 1, 15)
+                )
+            )
+
+
     def test_run_object_classifier_from_pixel_predictions(self):
         self.test_run_pixel_classifier()
         fp = czifile['path']