From ef72ee95b856cb128ae1e8693fea920f1ad5e2f7 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 1 Nov 2024 18:53:17 +0100 Subject: [PATCH] IntensityThresholdInstanceMaskSegmentationModel enforces a mono image but itself is not initiated with a channel --- model_server/base/accessors.py | 2 +- model_server/base/api.py | 1 - model_server/base/roiset.py | 11 +++++++---- tests/base/test_accessors.py | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index ee14524a..b250605d 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -525,7 +525,7 @@ class PatchStack(InMemoryDataAccessor): assuming the each patch in the patch stack represents a single object. :param mask of the same dimensions """ - if self.shape != mask.shape or not mask.is_mask(): + if not mask.can_mask(self): raise DataShapeError(f'Patch stack object dataframe expects a mask of the same dimensions') df = pd.DataFrame([ { diff --git a/model_server/base/api.py b/model_server/base/api.py index 081b3b03..972c909e 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -89,7 +89,6 @@ def list_active_models(): class BinaryThresholdSegmentationParams(BaseModel): tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') - channel: Union[int, int] = Field(0, description='Channel from which to compute binary threshold') @app.put('/models/seg/threshold/load/') diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 9a18a135..e8184712 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1185,7 +1185,7 @@ class RoiSetWithDerivedChannels(RoiSet): class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): - def __init__(self, tr: float = 0.5, channel: int = 0): + def __init__(self, tr: float = 0.5): """ Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all objects as class 1 if threshold = 0.0 @@ -1193,9 +1193,8 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo :param channel: channel to use for thresholding """ self.tr = tr - self.channel = channel self.loaded = self.load() - super().__init__(info={'tr': tr, 'channel': channel}) + super().__init__(info={'tr': tr}) def load(self): return True @@ -1207,6 +1206,10 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo allow_3d: bool = False, connect_3d: bool = True, ) -> GenericImageDataAccessor: + if img.chroma != 1: + raise ShapeMismatchError( + f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}' + ) if isinstance(img, PatchStack): # assume one object per patch df = img.get_object_df(mask) om = np.zeros(mask.shape, 'uint16') @@ -1218,7 +1221,7 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo labels = get_label_ids(mask) df = pd.DataFrame(regionprops_table( labels.data_yxz, - intensity_image=img.get_mono(self.channel).data_yxz, + intensity_image=img.data_yxz, properties=('label', 'area', 'intensity_mean') )) diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index c8e82aca..e73e3470 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -364,9 +364,9 @@ class TestPatchStackAccessor(unittest.TestCase): mask_data[0, 0:5, 0:5, :, :] = 255 mask_data[1, 0:10, 0:10, :, :] = 255 mask = PatchStack(mask_data) - df = acc.get_object_df(mask) + df = acc.get_mono(0).get_object_df(mask) # intensity means are centered around half of full range - self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-3)) + self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-2)) self.assertTrue(df['area'][1] / df['area'][0] == 4.0) return acc -- GitLab