diff --git a/model_server/base/models.py b/model_server/base/models.py
index 4fa0849031bcd318a66f49c8843763b0a99c2279..a8576d40c2a1be0d1f2b44ac65eb0a8354b5243d 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel):
         """
         if not mask.is_mask():
             raise InvalidInputImageError('Expecting a binary mask')
-        if not img.shape == mask.shape:
+        if img.hw != mask.hw or img.nz != mask.nz:
             raise InvalidInputImageError('Expect input image and mask to be the same shape')
 
     def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):