diff --git a/model_server/base/process.py b/model_server/base/process.py index 992c9fb93140ab49859f0807b55a1f8e321cd80d..94c6cf6206d4a0f1dcc4cfb78635807d58efd0c7 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -6,7 +6,7 @@ from math import ceil, floor import numpy as np import skimage from skimage.exposure import rescale_intensity - +from skimage.measure import find_contours def is_mask(img): """ @@ -117,6 +117,17 @@ def mask_largest_object( else: return img +def get_safe_contours(mask): + """ + Return a list of contour coordinates even if a mask is only one pixel across + """ + if mask.shape[0] == 1 or mask.shape[1] == 1: + c0 = mask.shape[0] - 1 + c1 = mask.shape[1] - 1 + return [np.array([(0, 0), (c0, c1)])] + else: + return find_contours(mask) + class Error(Exception): pass diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index bbe3a5f8208db4b4874df54d88525ffc162664a6..f87bb8c8dcaffaddd64eb628f6bca2ab1192beb1 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -16,7 +16,7 @@ from sklearn.linear_model import LinearRegression from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import InstanceSegmentationModel -from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb +from model_server.base.process import get_safe_contours, pad, rescale, resample_to_8bit, make_rgb from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image from model_server.base.accessors import generate_file_accessor, PatchStack from model_server.base.process import mask_largest_object @@ -501,20 +501,17 @@ class RoiSet(object): if kwargs.get('draw_mask'): mci = kwargs.get('mask_channel', 0) - # mask = np.zeros(patch.shape[0:2], dtype=bool) - # mask[roi.relative_slice[0:2]] = roi.binary_mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] if kwargs.get('draw_contour'): mci = kwargs.get('contour_channel', 0) - # mask = np.zeros(patch.shape[0:2], dtype=bool) - # mask[roi.relative_slice[0:2]] = roi.binary_mask for zi in range(0, patch.shape[3]): + contours = get_safe_contours(mask) patch[:, :, mci, zi] = draw_contours_on_patch( patch[:, :, mci, zi], - find_contours(mask) + contours ) if pad_to and expanded: @@ -610,7 +607,7 @@ class RoiSet(object): try: ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max - id_mask[sl] = r.label * bool_mask + id_mask[sl] = id_mask[sl] + r.label * bool_mask except Exception as e: raise DeserializeRoiSet(e) diff --git a/tests/test_process.py b/tests/test_process.py index 569ac1b797437381df4d80e3da749447ea52e91f..d2fb33b9cc9a8b6af04c8eacd99b6b87d7d64527 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,9 +1,10 @@ import unittest import numpy as np +from skimage.measure import find_contours -from model_server.base.process import mask_largest_object -from model_server.base.process import pad +from model_server.base.annotators import draw_contours_on_patch +from model_server.base.process import get_safe_contours, mask_largest_object, pad class TestProcessingUtilityMethods(unittest.TestCase): def setUp(self) -> None: @@ -56,3 +57,26 @@ class TestMaskLargestObject(unittest.TestCase): self.assertTrue(np.all(np.unique(masked) == [0, 255])) self.assertTrue(np.all(masked[:, 3:5] == 0)) self.assertTrue(np.all(masked[3:5, :] == 0)) + + +class TestSafeContours(unittest.TestCase): + def setUp(self) -> None: + self.patch = np.ones((10, 20), dtype='uint8') + self.mask_ref = np.zeros((10, 20), dtype=bool) + self.mask_ref[0:5, 0:10] = True + self.mask_test = np.ones((1, 20), dtype=bool) + + def test_contours_on_compliant_mask(self): + con = get_safe_contours(self.mask_ref) + patch = self.patch.copy() + self.assertEqual((patch == 0).sum(), 0) + patch = draw_contours_on_patch(patch, con) + self.assertEqual((patch == 0).sum(), 14) + + def test_contours_on_noncompliant_mask(self): + con = get_safe_contours(self.mask_test) + patch = self.patch.copy() + self.assertEqual((patch == 0).sum(), 0) + patch = draw_contours_on_patch(self.patch, con) + self.assertEqual((patch == 0).sum(), 20) + self.assertEqual((patch[0, :] == 0).sum(), 20) \ No newline at end of file