diff --git a/model_server/base/models.py b/model_server/base/models.py index 4fa0849031bcd318a66f49c8843763b0a99c2279..a8576d40c2a1be0d1f2b44ac65eb0a8354b5243d 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel): """ if not mask.is_mask(): raise InvalidInputImageError('Expecting a binary mask') - if not img.shape == mask.shape: + if img.hw != mask.hw or img.nz != mask.nz: raise InvalidInputImageError('Expect input image and mask to be the same shape') def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):