diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index ee14524a63608dff9c2b3b547bab0a98a5eec657..b250605d4df5de6b1cd18d3a2570a85ed532f319 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -525,7 +525,7 @@ class PatchStack(InMemoryDataAccessor): assuming the each patch in the patch stack represents a single object. :param mask of the same dimensions """ - if self.shape != mask.shape or not mask.is_mask(): + if not mask.can_mask(self): raise DataShapeError(f'Patch stack object dataframe expects a mask of the same dimensions') df = pd.DataFrame([ { diff --git a/model_server/base/api.py b/model_server/base/api.py index 081b3b030c2c8e794a3287aa60d6dd9f27673101..972c909e22ed28c6bdd115d96c10be8ffa3bf9e3 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -89,7 +89,6 @@ def list_active_models(): class BinaryThresholdSegmentationParams(BaseModel): tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') - channel: Union[int, int] = Field(0, description='Channel from which to compute binary threshold') @app.put('/models/seg/threshold/load/') diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 9a18a135c0c8fd91e892bb5ae0487548a673db77..e81847126e6ed78fa9808b8ae0af466ca760e3be 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1185,7 +1185,7 @@ class RoiSetWithDerivedChannels(RoiSet): class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel): - def __init__(self, tr: float = 0.5, channel: int = 0): + def __init__(self, tr: float = 0.5): """ Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all objects as class 1 if threshold = 0.0 @@ -1193,9 +1193,8 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo :param channel: channel to use for thresholding """ self.tr = tr - self.channel = channel self.loaded = self.load() - super().__init__(info={'tr': tr, 'channel': channel}) + super().__init__(info={'tr': tr}) def load(self): return True @@ -1207,6 +1206,10 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo allow_3d: bool = False, connect_3d: bool = True, ) -> GenericImageDataAccessor: + if img.chroma != 1: + raise ShapeMismatchError( + f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}' + ) if isinstance(img, PatchStack): # assume one object per patch df = img.get_object_df(mask) om = np.zeros(mask.shape, 'uint16') @@ -1218,7 +1221,7 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo labels = get_label_ids(mask) df = pd.DataFrame(regionprops_table( labels.data_yxz, - intensity_image=img.get_mono(self.channel).data_yxz, + intensity_image=img.data_yxz, properties=('label', 'area', 'intensity_mean') )) diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index c8e82aca77a88befed42cfd93b7d0ed6f3ec3ec6..e73e347029f5ec42cd01696ebdf01a0ae4c0cff0 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -364,9 +364,9 @@ class TestPatchStackAccessor(unittest.TestCase): mask_data[0, 0:5, 0:5, :, :] = 255 mask_data[1, 0:10, 0:10, :, :] = 255 mask = PatchStack(mask_data) - df = acc.get_object_df(mask) + df = acc.get_mono(0).get_object_df(mask) # intensity means are centered around half of full range - self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-3)) + self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-2)) self.assertTrue(df['area'][1] / df['area'][0] == 4.0) return acc