diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 2fe4515dcf95492538a75aa8b479dad901ee4f9e..ee14524a63608dff9c2b3b547bab0a98a5eec657 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -41,6 +41,9 @@ class GenericImageDataAccessor(ABC): def is_mask(self): return is_mask(self._data) + def can_mask(self, acc): + return self.is_mask() and self.shape == acc.get_mono(0).shape + def get_channels(self, channels: list, mip: bool = False): carr = [int(c) for c in channels] if mip: diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index be062316c7d44c93d5ac9bbb5a9e2bd59d55ad19..c8e82aca77a88befed42cfd93b7d0ed6f3ec3ec6 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -336,6 +336,23 @@ class TestPatchStackAccessor(unittest.TestCase): self.assertEqual(acc.get_mono(channel=0, mip=True).data_yx.shape, (n, h, w)) return acc + def test_can_mask(self): + w = 30 + h = 20 + n = 2 + nz = 3 + nc = 3 + acc1 = PatchStack(_random_int(n, h, w, nc, nz)) + acc2 = PatchStack(_random_int(n, 2*h, w, nc, nz)) + mask = PatchStack(_random_int(n, h, w, 1, nz) > 0.5) + self.assertFalse(acc1.is_mask()) + self.assertFalse(acc2.is_mask()) + self.assertTrue(mask.is_mask()) + self.assertFalse(acc1.can_mask(acc2)) + self.assertFalse(acc1.can_mask(acc2)) + self.assertTrue(mask.can_mask(acc1)) + self.assertFalse(mask.can_mask(acc2)) + def test_object_df(self): w = 30 h = 20 @@ -343,7 +360,7 @@ class TestPatchStackAccessor(unittest.TestCase): nz = 3 nc = 3 acc = PatchStack(_random_int(n, h, w, nc, nz)) - mask_data= np.zeros((n, h, w, nc, nz), dtype='uint8') + mask_data= np.zeros((n, h, w, 1, nz), dtype='uint8') mask_data[0, 0:5, 0:5, :, :] = 255 mask_data[1, 0:10, 0:10, :, :] = 255 mask = PatchStack(mask_data)