diff --git a/model_server/base/api.py b/model_server/base/api.py index 608738a166768fb1ff8420a8cd0b4a58b170a3be..ba047efe7a1d2a0e487fe8a6582fcf58be9ac670 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,10 +1,11 @@ +from pydantic import BaseModel, Field from typing import Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor +from .models import BinaryThresholdSegmentationModel from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError - app = FastAPI(debug=True) from .pipelines.router import router @@ -68,6 +69,15 @@ def list_session_log() -> list: def list_active_models(): return session.describe_loaded_models() +class BinaryThresholdSegmentationParams(BaseModel): + channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.') + tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') + +@app.put('/models/seg/binary_threshold/load/') +def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict: + result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p) + session.log_info(f'Loaded binary threshold segmentation model {result}') + return {'model_id': result} @app.get('/accessors') def list_accessors(): diff --git a/model_server/base/models.py b/model_server/base/models.py index eaded2c9f24c8793600019d86a1dcbb5cdab7c52..346bfaa4d2a1df6c438b0d2689bbf17c09c8c570 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -1,15 +1,13 @@ from abc import ABC, abstractmethod -from math import floor import numpy as np -from pydantic import BaseModel -from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack +from .accessors import GenericImageDataAccessor, PatchStack class Model(ABC): - def __init__(self, autoload=True, params: BaseModel = None): + def __init__(self, autoload=True, params: dict = None): """ Abstract base class for an inference model that uses image data as an input. @@ -18,7 +16,7 @@ class Model(ABC): """ self.autoload = autoload if params: - self.params = params.dict() + self.params = params self.loaded = False if not autoload: return None @@ -130,15 +128,24 @@ class InstanceSegmentationModel(ImageToImageModel): class BinaryThresholdSegmentationModel(SemanticSegmentationModel): + """ + Trivial but functional model that labels all pixels above an intensity threshold as class 1 + """ - def __init__(self, tr: float = 0.5): - self.tr = tr + def __init__(self, params=None): + self.tr = params['tr'] + self.channel = params['channel'] + self.loaded = True def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - return img.apply(lambda x: x > self.tr), {'success': True} + if self.channel: + acc = img.get_mono(self.channel) + else: + acc = img + return acc.get_mono(self.channel).apply(lambda x: x > self.tr) def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - return self.infer(img, **kwargs)[0] + return self.infer(img, **kwargs) def load(self): pass diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 73a361880f3cb20fe0871c2374496cb83efb909d..9593b7f3a6b4748d7d228c89b0c619f3c38f0677 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -182,4 +182,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass): sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict'] self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '') acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif') - self.assertEqual(sd, acc_out.shape_dict) \ No newline at end of file + self.assertEqual(sd, acc_out.shape_dict) + + def test_load_binary_segmentation_model(self): + mid = self.assertPutSuccess( + '/models/seg/binary_threshold/load/', body={'channel': 0, 'tr': 10} + )['model_id'] + return mid \ No newline at end of file diff --git a/tests/base/test_model.py b/tests/base/test_model.py index d975f7cd8725e0215391b4a526feab7cd69eeb31..38e51c0b12650ffdbf70a86f0e7b6b5f54102fa3 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -56,9 +56,8 @@ class TestCziImageFileAccess(unittest.TestCase): return img, mask def test_binary_segmentation(self): - model = BinaryThresholdSegmentationModel(tr=3e4) - img = self.cf.get_mono(0) - res = model.label_pixel_class(img) + model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0}) + res = model.label_pixel_class(self.cf) self.assertTrue(res.is_mask()) def test_dummy_instance_segmentation(self):