From 46366c5f909dc045d604c7a9bea3e24c670591fa Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 5 Apr 2024 14:17:36 +0200
Subject: [PATCH] Corrected error with overlapping bounding boxes, whereby a
 second box would overwrite a first; instead, masks are now just additive;
 also catch case when trying to draw contours from a mask that is a single
 pixel across

---
 model_server/base/process.py | 13 ++++++++++++-
 model_server/base/roiset.py  | 11 ++++-------
 tests/test_process.py        | 28 ++++++++++++++++++++++++++--
 3 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/model_server/base/process.py b/model_server/base/process.py
index 992c9fb9..94c6cf62 100644
--- a/model_server/base/process.py
+++ b/model_server/base/process.py
@@ -6,7 +6,7 @@ from math import ceil, floor
 import numpy as np
 import skimage
 from skimage.exposure import rescale_intensity
-
+from skimage.measure import find_contours
 
 def is_mask(img):
     """
@@ -117,6 +117,17 @@ def mask_largest_object(
     else:
         return img
 
+def get_safe_contours(mask):
+    """
+    Return a list of contour coordinates even if a mask is only one pixel across
+    """
+    if mask.shape[0] == 1 or mask.shape[1] == 1:
+        c0 = mask.shape[0] - 1
+        c1 = mask.shape[1] - 1
+        return [np.array([(0, 0), (c0, c1)])]
+    else:
+        return find_contours(mask)
+
 
 class Error(Exception):
     pass
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index bbe3a5f8..f87bb8c8 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -16,7 +16,7 @@ from sklearn.linear_model import LinearRegression
 
 from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.base.models import InstanceSegmentationModel
-from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb
+from model_server.base.process import get_safe_contours, pad, rescale, resample_to_8bit, make_rgb
 from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
 from model_server.base.accessors import generate_file_accessor, PatchStack
 from model_server.base.process import mask_largest_object
@@ -501,20 +501,17 @@ class RoiSet(object):
 
             if kwargs.get('draw_mask'):
                 mci = kwargs.get('mask_channel', 0)
-                # mask = np.zeros(patch.shape[0:2], dtype=bool)
-                # mask[roi.relative_slice[0:2]] = roi.binary_mask
                 for zi in range(0, patch.shape[3]):
                     patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
 
             if kwargs.get('draw_contour'):
                 mci = kwargs.get('contour_channel', 0)
-                # mask = np.zeros(patch.shape[0:2], dtype=bool)
-                # mask[roi.relative_slice[0:2]] = roi.binary_mask
 
                 for zi in range(0, patch.shape[3]):
+                    contours = get_safe_contours(mask)
                     patch[:, :, mci, zi] = draw_contours_on_patch(
                         patch[:, :, mci, zi],
-                        find_contours(mask)
+                        contours
                     )
 
             if pad_to and expanded:
@@ -610,7 +607,7 @@ class RoiSet(object):
             try:
                 ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname)
                 bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max
-                id_mask[sl] = r.label * bool_mask
+                id_mask[sl] = id_mask[sl] + r.label * bool_mask
             except Exception as e:
                 raise DeserializeRoiSet(e)
 
diff --git a/tests/test_process.py b/tests/test_process.py
index 569ac1b7..d2fb33b9 100644
--- a/tests/test_process.py
+++ b/tests/test_process.py
@@ -1,9 +1,10 @@
 import unittest
 
 import numpy as np
+from skimage.measure import find_contours
 
-from model_server.base.process import mask_largest_object
-from model_server.base.process import pad
+from model_server.base.annotators import draw_contours_on_patch
+from model_server.base.process import get_safe_contours, mask_largest_object, pad
 
 class TestProcessingUtilityMethods(unittest.TestCase):
     def setUp(self) -> None:
@@ -56,3 +57,26 @@ class TestMaskLargestObject(unittest.TestCase):
         self.assertTrue(np.all(np.unique(masked) == [0, 255]))
         self.assertTrue(np.all(masked[:, 3:5] == 0))
         self.assertTrue(np.all(masked[3:5, :] == 0))
+
+
+class TestSafeContours(unittest.TestCase):
+    def setUp(self) -> None:
+        self.patch = np.ones((10, 20), dtype='uint8')
+        self.mask_ref = np.zeros((10, 20), dtype=bool)
+        self.mask_ref[0:5, 0:10] = True
+        self.mask_test = np.ones((1, 20), dtype=bool)
+
+    def test_contours_on_compliant_mask(self):
+        con = get_safe_contours(self.mask_ref)
+        patch = self.patch.copy()
+        self.assertEqual((patch == 0).sum(), 0)
+        patch = draw_contours_on_patch(patch, con)
+        self.assertEqual((patch == 0).sum(), 14)
+
+    def test_contours_on_noncompliant_mask(self):
+        con = get_safe_contours(self.mask_test)
+        patch = self.patch.copy()
+        self.assertEqual((patch == 0).sum(), 0)
+        patch = draw_contours_on_patch(self.patch, con)
+        self.assertEqual((patch == 0).sum(), 20)
+        self.assertEqual((patch[0, :] == 0).sum(), 20)
\ No newline at end of file
-- 
GitLab