From dae64e307583191026de750d38a9421f5d90dc71 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 25 Oct 2024 11:50:18 +0200 Subject: [PATCH] Can load binary segmentation model --- model_server/base/api.py | 12 +++++++++++- model_server/base/models.py | 25 ++++++++++++++++--------- tests/base/test_api.py | 8 +++++++- tests/base/test_model.py | 5 ++--- 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 608738a1..ba047efe 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,10 +1,11 @@ +from pydantic import BaseModel, Field from typing import Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor +from .models import BinaryThresholdSegmentationModel from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError - app = FastAPI(debug=True) from .pipelines.router import router @@ -68,6 +69,15 @@ def list_session_log() -> list: def list_active_models(): return session.describe_loaded_models() +class BinaryThresholdSegmentationParams(BaseModel): + channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.') + tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation') + +@app.put('/models/seg/binary_threshold/load/') +def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict: + result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p) + session.log_info(f'Loaded binary threshold segmentation model {result}') + return {'model_id': result} @app.get('/accessors') def list_accessors(): diff --git a/model_server/base/models.py b/model_server/base/models.py index eaded2c9..346bfaa4 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -1,15 +1,13 @@ from abc import ABC, abstractmethod -from math import floor import numpy as np -from pydantic import BaseModel -from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack +from .accessors import GenericImageDataAccessor, PatchStack class Model(ABC): - def __init__(self, autoload=True, params: BaseModel = None): + def __init__(self, autoload=True, params: dict = None): """ Abstract base class for an inference model that uses image data as an input. @@ -18,7 +16,7 @@ class Model(ABC): """ self.autoload = autoload if params: - self.params = params.dict() + self.params = params self.loaded = False if not autoload: return None @@ -130,15 +128,24 @@ class InstanceSegmentationModel(ImageToImageModel): class BinaryThresholdSegmentationModel(SemanticSegmentationModel): + """ + Trivial but functional model that labels all pixels above an intensity threshold as class 1 + """ - def __init__(self, tr: float = 0.5): - self.tr = tr + def __init__(self, params=None): + self.tr = params['tr'] + self.channel = params['channel'] + self.loaded = True def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - return img.apply(lambda x: x > self.tr), {'success': True} + if self.channel: + acc = img.get_mono(self.channel) + else: + acc = img + return acc.get_mono(self.channel).apply(lambda x: x > self.tr) def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - return self.infer(img, **kwargs)[0] + return self.infer(img, **kwargs) def load(self): pass diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 73a36188..9593b7f3 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -182,4 +182,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass): sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict'] self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '') acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif') - self.assertEqual(sd, acc_out.shape_dict) \ No newline at end of file + self.assertEqual(sd, acc_out.shape_dict) + + def test_load_binary_segmentation_model(self): + mid = self.assertPutSuccess( + '/models/seg/binary_threshold/load/', body={'channel': 0, 'tr': 10} + )['model_id'] + return mid \ No newline at end of file diff --git a/tests/base/test_model.py b/tests/base/test_model.py index d975f7cd..38e51c0b 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -56,9 +56,8 @@ class TestCziImageFileAccess(unittest.TestCase): return img, mask def test_binary_segmentation(self): - model = BinaryThresholdSegmentationModel(tr=3e4) - img = self.cf.get_mono(0) - res = model.label_pixel_class(img) + model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0}) + res = model.label_pixel_class(self.cf) self.assertTrue(res.is_mask()) def test_dummy_instance_segmentation(self): -- GitLab