diff --git a/model_server/base/api.py b/model_server/base/api.py index ba047efe7a1d2a0e487fe8a6582fcf58be9ac670..f1f658135e7729822f86f11a99e50392f9b35242 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -69,16 +69,18 @@ def list_session_log() -> list: def list_active_models(): return session.describe_loaded_models() + class BinaryThresholdSegmentationParams(BaseModel): - channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.') tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') + @app.put('/models/seg/binary_threshold/load/') def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict: result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p) session.log_info(f'Loaded binary threshold segmentation model {result}') return {'model_id': result} + @app.get('/accessors') def list_accessors(): return session.list_accessors() diff --git a/model_server/base/models.py b/model_server/base/models.py index 346bfaa4d2a1df6c438b0d2689bbf17c09c8c570..26e8af624eb1a98ae2067785b06a55b40dc45621 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -134,18 +134,13 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel): def __init__(self, params=None): self.tr = params['tr'] - self.channel = params['channel'] self.loaded = True - def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - if self.channel: - acc = img.get_mono(self.channel) - else: - acc = img - return acc.get_mono(self.channel).apply(lambda x: x > self.tr) + def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): + return acc.apply(lambda x: x > self.tr) - def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - return self.infer(img, **kwargs) + def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + return self.infer(acc, **kwargs) def load(self): pass diff --git a/tests/base/test_model.py b/tests/base/test_model.py index 38e51c0b12650ffdbf70a86f0e7b6b5f54102fa3..8340111c5b45e764b76927c1534a64486d9cc2b6 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -56,7 +56,7 @@ class TestCziImageFileAccess(unittest.TestCase): return img, mask def test_binary_segmentation(self): - model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0}) + model = BinaryThresholdSegmentationModel({'tr': 3e4}) res = model.label_pixel_class(self.cf) self.assertTrue(res.is_mask())