From 6735a2f784f9baf0f12c61e0711aa4fdab9e6f2a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 12 Aug 2024 10:40:36 +0200
Subject: [PATCH] Parameters auto-validates anything ending in accessor_id and
 model_id; call_pipeline() maps these to objects and sends to pipeline
 function as dicts of respective objects

---
 model_server/base/pipelines/params.py  | 38 ++++++++++++++------------
 model_server/base/pipelines/segment.py | 22 +++++++++++----
 model_server/base/pipelines/util.py    | 19 +++++++++----
 tests/base/test_api.py                 | 31 +++++++++++++++++----
 tests/base/test_pipelines.py           |  2 ++
 5 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/model_server/base/pipelines/params.py b/model_server/base/pipelines/params.py
index 9f1bda03..d07d68fe 100644
--- a/model_server/base/pipelines/params.py
+++ b/model_server/base/pipelines/params.py
@@ -1,31 +1,33 @@
 from typing import List, Union
 
 from fastapi import HTTPException
-from pydantic import BaseModel, Field, validator
-
+from pydantic import BaseModel, Field, root_validator
 from ..session import session, AccessorIdError
 
 
-class SingleModelPipelineParams(BaseModel):
+class PipelineParams(BaseModel):
     accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
     model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
     keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
 
-    @validator('model_id')
-    def models_are_loaded(cls, model_id):
-        if model_id not in session.describe_loaded_models().keys():
-            raise HTTPException(status_code=409, detail=f'Model with ID {model_id} has not been loaded')
-        return model_id
-
-    @validator('accessor_id')
-    def accessors_are_loaded(cls, accessor_id):
-        try:
-            info = session.get_accessor_info(accessor_id)
-        except AccessorIdError as e:
-            raise HTTPException(status_code=409, detail=str(e))
-        if not info['loaded']:
-            raise HTTPException(status_code=409, detail=f'Accessor with ID {accessor_id} is not loaded')
-        return accessor_id
+    @root_validator(pre=False)
+    def models_are_loaded(cls, dd):
+        for k, v in dd.items():
+            if k.endswith('model_id') and v not in session.describe_loaded_models().keys():
+                raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
+        return dd
+
+    @root_validator(pre=False)
+    def accessors_are_loaded(cls, dd):
+        for k, v in dd.items():
+            if k.endswith('accessor_id'):
+                try:
+                    info = session.get_accessor_info(v)
+                except AccessorIdError as e:
+                    raise HTTPException(status_code=409, detail=str(e))
+                if not info['loaded']:
+                    raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
+        return dd
 
 
 class PipelineRecord(BaseModel):
diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py
index 3dfa86d3..2e6fde6c 100644
--- a/model_server/base/pipelines/segment.py
+++ b/model_server/base/pipelines/segment.py
@@ -1,9 +1,11 @@
+from typing import Dict
+
 from fastapi import APIRouter
 
 from .util import call_pipeline
 from ..accessors import GenericImageDataAccessor
-from ..models import SemanticSegmentationModel
-from .params import SingleModelPipelineParams, PipelineRecord
+from ..models import Model
+from .params import PipelineParams, PipelineRecord
 from ..process import smooth
 from ..util import PipelineTrace
 
@@ -14,7 +16,9 @@ router = APIRouter(
     tags=['pipelines'],
 )
 
-class SegmentParams(SingleModelPipelineParams):
+class SegmentParams(PipelineParams):
+    accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
+    model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
     channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
 
 
@@ -25,11 +29,17 @@ def segment(p: SegmentParams) -> PipelineRecord:
     """
     return call_pipeline(segment_pipeline, p)
 
-def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
+def segment_pipeline(
+        accessors: Dict[str, GenericImageDataAccessor],
+        models: Dict[str, Model],
+        **k
+) -> PipelineTrace:
+    input_accessor = accessors.get('accessor')
+    model = models.get('model')
     d = PipelineTrace()
-    d['input'] = acc_in
+    d['input'] = input_accessor
     if ch := k.get('channel') is not None:
-        d['mono'] = acc_in.get_mono(ch)
+        d['mono'] = input_accessor.get_mono(ch)
     d['inference'] = model.label_pixel_class(d.last)
     if sm := k.get('smooth') is not None:
         d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
diff --git a/model_server/base/pipelines/util.py b/model_server/base/pipelines/util.py
index a4a37feb..50459da4 100644
--- a/model_server/base/pipelines/util.py
+++ b/model_server/base/pipelines/util.py
@@ -1,13 +1,22 @@
-from .params import SingleModelPipelineParams, PipelineRecord
+from .params import PipelineParams, PipelineRecord
 from ..session import session
 
 
-def call_pipeline(func, p: SingleModelPipelineParams):
-    acc_in = session.get_accessor(p.accessor_id, pop=True)
+def call_pipeline(func, p: PipelineParams):
+
+    accessors_in = {}
+    for k, v in p.dict().items():
+        if k.endswith('accessor_id'):
+            accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True)
+
+    models = {}
+    for k, v in p.dict().items():
+        if k.endswith('model_id'):
+            models[k.split('_id')[0]] = session.models[v]['object']
 
     steps = func(
-        acc_in,
-        session.models[p.model_id]['object'],
+        accessors_in,
+        models,
         **p.dict(),
     )
 
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index 35476d2b..bf5bc87e 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -127,12 +127,33 @@ class TestApiFromAutomatedClient(TestServerTestCase):
 
     def test_pipeline_errors_when_ids_not_found(self):
-        model_id = 'not_a_real_model'
-        resp = self._put(
-            f'pipelines/segment',
-            body={'model_id': model_id, 'accessor_id': 'fake'}
+        self.copy_input_file_to_server()
+        model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id']
+        in_acc_id = self._put(
+            f'accessors/read_from_file',
+            query={
+                'filename': czifile['name'],
+            },
+        ).json()
+
+        # respond with 409 for invalid accessor_id
+        self.assertEqual(
+            self._put(
+                f'pipelines/segment',
+                body={'model_id': model_id, 'accessor_id': 'fake'}
+            ).status_code,
+            409
         )
-        self.assertEqual(resp.status_code, 409, resp.content.decode())
+
+        # respond with 409 for invalid model_id
+        self.assertEqual(
+            self._put(
+                f'pipelines/segment',
+                body={'model_id': 'fake', 'accessor_id': in_acc_id}
+            ).status_code,
+            409
+        )
+
     def test_i2i_dummy_inference_by_api(self):
         self.copy_input_file_to_server()
diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py
index f900f941..bbcd2305 100644
--- a/tests/base/test_pipelines.py
+++ b/tests/base/test_pipelines.py
@@ -3,12 +3,14 @@ import unittest
 from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
 from model_server.base.pipelines import segment
+
 import model_server.conf.testing as conf
 from tests.base.test_model import DummySemanticSegmentationModel
 
 czifile = conf.meta['image_files']['czifile']
 output_path = conf.meta['output_path']
 
+
 class TestSegmentationPipeline(unittest.TestCase):
     def setUp(self) -> None:
         self.model = DummySemanticSegmentationModel()
--
GitLab
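
Not part of the patch above: a minimal sketch of how a downstream pipeline could build on the new convention. The mechanism comes from the patch (any field name ending in accessor_id or model_id is checked against the session by the root validators on PipelineParams, and call_pipeline() strips the trailing '_id' when keying the accessor and model dicts it passes to the pipeline function), but the names DualInputParams, nuc_accessor_id, nuc_model_id, dual_input_pipeline and dual_input are hypothetical and shown only for illustration.

from typing import Dict

from pydantic import Field

from model_server.base.accessors import GenericImageDataAccessor
from model_server.base.models import Model
from model_server.base.pipelines.params import PipelineParams, PipelineRecord
from model_server.base.pipelines.util import call_pipeline
from model_server.base.util import PipelineTrace


class DualInputParams(PipelineParams):
    # hypothetical extra fields; the *_id suffixes trigger auto-validation,
    # in addition to the inherited accessor_id and model_id
    nuc_accessor_id: str = Field(description='ID of a loaded accessor holding the nuclei channel')
    nuc_model_id: str = Field(description='ID of a loaded nuclei segmentation model')


def dual_input_pipeline(
        accessors: Dict[str, GenericImageDataAccessor],
        models: Dict[str, Model],
        **k
) -> PipelineTrace:
    # call_pipeline() keys the dicts by field name minus '_id', so this
    # pipeline sees 'accessor', 'nuc_accessor', 'model' and 'nuc_model'
    d = PipelineTrace()
    d['input'] = accessors.get('nuc_accessor')
    d['inference'] = models.get('nuc_model').label_pixel_class(d.last)
    return d


def dual_input(p: DualInputParams) -> PipelineRecord:
    # endpoint body in the style of segment(); wiring into an APIRouter is omitted
    return call_pipeline(dual_input_pipeline, p)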