diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 2fe4515dcf95492538a75aa8b479dad901ee4f9e..ee14524a63608dff9c2b3b547bab0a98a5eec657 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -41,6 +41,9 @@ class GenericImageDataAccessor(ABC):
     def is_mask(self):
         return is_mask(self._data)
 
+    def can_mask(self, acc):
+        return self.is_mask() and self.shape == acc.get_mono(0).shape
+
     def get_channels(self, channels: list, mip: bool = False):
         carr = [int(c) for c in channels]
         if mip:
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index be062316c7d44c93d5ac9bbb5a9e2bd59d55ad19..c8e82aca77a88befed42cfd93b7d0ed6f3ec3ec6 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -336,6 +336,23 @@ class TestPatchStackAccessor(unittest.TestCase):
         self.assertEqual(acc.get_mono(channel=0, mip=True).data_yx.shape, (n, h, w))
         return acc
 
+    def test_can_mask(self):
+        w = 30
+        h = 20
+        n = 2
+        nz = 3
+        nc = 3
+        acc1 = PatchStack(_random_int(n, h, w, nc, nz))
+        acc2 = PatchStack(_random_int(n, 2*h, w, nc, nz))
+        mask = PatchStack(_random_int(n, h, w, 1, nz) > 0.5)
+        self.assertFalse(acc1.is_mask())
+        self.assertFalse(acc2.is_mask())
+        self.assertTrue(mask.is_mask())
+        self.assertFalse(acc1.can_mask(acc2))
+        self.assertFalse(acc1.can_mask(acc2))
+        self.assertTrue(mask.can_mask(acc1))
+        self.assertFalse(mask.can_mask(acc2))
+
     def test_object_df(self):
         w = 30
         h = 20
@@ -343,7 +360,7 @@ class TestPatchStackAccessor(unittest.TestCase):
         nz = 3
         nc = 3
         acc = PatchStack(_random_int(n, h, w, nc, nz))
-        mask_data= np.zeros((n, h, w, nc, nz), dtype='uint8')
+        mask_data= np.zeros((n, h, w, 1, nz), dtype='uint8')
         mask_data[0, 0:5, 0:5, :, :] = 255
         mask_data[1, 0:10, 0:10, :, :] = 255
         mask = PatchStack(mask_data)