diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 0cb6370cf217746b40f30bc32eb6ea329e4f74e9..4126e14e11e2591fe00c9e21b0430eda0eb7085f 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -36,7 +36,6 @@ class IlastikModel(Model):
             raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
 
         self.shell = None
-        self.axes = None
         super().__init__(autoload, params)
 
     def load(self):
@@ -65,18 +64,10 @@ class IlastikModel(Model):
             raise ParameterExpectedError(
                 f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
             )
-        self.axes = [
-            a['key'].upper() for a in json.loads(h5[f'Input Data/infos/lane0000/Raw Data/axistags'][()])['axes']
-        ]
         self.shell = shell
 
         return True
 
-
-class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
-    model_id = 'ilastik_pixel_classification'
-    operations = ['segment', ]
-
     @property
     def model_shape_dict(self):
         raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
@@ -97,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
     def model_3d(self):
         return self.model_shape_dict['Z'] > 1
 
+
+class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
+    model_id = 'ilastik_pixel_classification'
+    operations = ['segment', ]
+
     @staticmethod
     def get_workflow():
         from ilastik.workflows import PixelClassificationWorkflow
@@ -145,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
         return ObjectClassificationWorkflowBinary
 
     def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
+            raise IlastikInputShapeError()
+
         assert segmentation_img.is_mask()
         if isinstance(input_img, PatchStack):
             assert isinstance(segmentation_img, PatchStack)
@@ -202,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         return ObjectClassificationWorkflowPrediction
 
     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
+            raise IlastikInputShapeError()
+
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
         tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')