From ef72ee95b856cb128ae1e8693fea920f1ad5e2f7 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 1 Nov 2024 18:53:17 +0100
Subject: [PATCH] IntensityThresholdInstanceMaskSegmentationModel enforces a
 mono image but itself is not initiated with a channel

---
 model_server/base/accessors.py |  2 +-
 model_server/base/api.py       |  1 -
 model_server/base/roiset.py    | 11 +++++++----
 tests/base/test_accessors.py   |  4 ++--
 4 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index ee14524a..b250605d 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -525,7 +525,7 @@ class PatchStack(InMemoryDataAccessor):
         assuming the each patch in the patch stack represents a single object.
         :param mask of the same dimensions
         """
-        if self.shape != mask.shape or not mask.is_mask():
+        if not mask.can_mask(self):
             raise DataShapeError(f'Patch stack object dataframe expects a mask of the same dimensions')
         df = pd.DataFrame([
             {
diff --git a/model_server/base/api.py b/model_server/base/api.py
index 081b3b03..972c909e 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -89,7 +89,6 @@ def list_active_models():
 
 class BinaryThresholdSegmentationParams(BaseModel):
     tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
-    channel: Union[int, int] = Field(0, description='Channel from which to compute binary threshold')
 
 
 @app.put('/models/seg/threshold/load/')
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 9a18a135..e8184712 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1185,7 +1185,7 @@ class RoiSetWithDerivedChannels(RoiSet):
 
 
 class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
-    def __init__(self, tr: float = 0.5, channel: int = 0):
+    def __init__(self, tr: float = 0.5):
         """
         Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all
         objects as class 1 if threshold = 0.0
@@ -1193,9 +1193,8 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo
         :param channel: channel to use for thresholding
         """
         self.tr = tr
-        self.channel = channel
         self.loaded = self.load()
-        super().__init__(info={'tr': tr, 'channel': channel})
+        super().__init__(info={'tr': tr})
 
     def load(self):
         return True
@@ -1207,6 +1206,10 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo
             allow_3d: bool = False,
             connect_3d: bool = True,
     ) -> GenericImageDataAccessor:
+        if img.chroma != 1:
+            raise ShapeMismatchError(
+                f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}'
+            )
         if isinstance(img, PatchStack):  # assume one object per patch
             df = img.get_object_df(mask)
             om = np.zeros(mask.shape, 'uint16')
@@ -1218,7 +1221,7 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo
             labels = get_label_ids(mask)
             df = pd.DataFrame(regionprops_table(
                 labels.data_yxz,
-                intensity_image=img.get_mono(self.channel).data_yxz,
+                intensity_image=img.data_yxz,
                 properties=('label', 'area', 'intensity_mean')
             ))
 
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index c8e82aca..e73e3470 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -364,9 +364,9 @@ class TestPatchStackAccessor(unittest.TestCase):
         mask_data[0, 0:5, 0:5, :, :] = 255
         mask_data[1, 0:10, 0:10, :, :] = 255
         mask = PatchStack(mask_data)
-        df = acc.get_object_df(mask)
+        df = acc.get_mono(0).get_object_df(mask)
         # intensity means are centered around half of full range
-        self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-3))
+        self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-2))
         self.assertTrue(df['area'][1] / df['area'][0] == 4.0)
         return acc
 
-- 
GitLab