diff --git a/model_server/base/models.py b/model_server/base/models.py
index 4fa0849031bcd318a66f49c8843763b0a99c2279..a8576d40c2a1be0d1f2b44ac65eb0a8354b5243d 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel):
         """
         if not mask.is_mask():
             raise InvalidInputImageError('Expecting a binary mask')
-        if not img.shape == mask.shape:
+        if img.hw != mask.hw or img.nz != mask.nz:
             raise InvalidInputImageError('Expect input image and mask to be the same shape')
 
     def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 7229b43eb601dc0703499392403e42e1bce3bf98..4126e14e11e2591fe00c9e21b0430eda0eb7085f 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -53,14 +53,13 @@ class IlastikModel(Model):
         shell = app.main(args, init_logging=False)
 
         # validate if inputs are embedded in project file
-        input_groups = shell.projectManager.currentProjectFile['Input Data']['infos']
-        lanes = input_groups.keys()
-        for ll in lanes:
-            input_types = input_groups[ll]
-            for tt in input_types:
-                ds_loc = input_groups[ll][tt].get('location', False)
-                if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem':
+        h5 = shell.projectManager.currentProjectFile
+        for lane in h5['Input Data/infos'].keys():
+            for role in h5[f'Input Data/infos/{lane}'].keys():
+                grp = h5[f'Input Data/infos/{lane}/{role}']
+                if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal':
                     raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')
+            assert True
         if not isinstance(shell.workflow, self.get_workflow()):
             raise ParameterExpectedError(
                 f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
@@ -69,11 +68,6 @@ class IlastikModel(Model):
 
         return True
 
-
-class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
-    model_id = 'ilastik_pixel_classification'
-    operations = ['segment', ]
-
     @property
     def model_shape_dict(self):
         raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
@@ -94,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
     def model_3d(self):
         return self.model_shape_dict['Z'] > 1
 
+
+class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
+    model_id = 'ilastik_pixel_classification'
+    operations = ['segment', ]
+
     @staticmethod
     def get_workflow():
         from ilastik.workflows import PixelClassificationWorkflow
@@ -142,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
         return ObjectClassificationWorkflowBinary
 
     def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
+            raise IlastikInputShapeError()
+
         assert segmentation_img.is_mask()
         if isinstance(input_img, PatchStack):
             assert isinstance(segmentation_img, PatchStack)
@@ -199,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         return ObjectClassificationWorkflowPrediction
 
     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
+            raise IlastikInputShapeError()
+
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
         tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')