Commit 6735a2f7 authored by Christopher Randolph Rhodes


PipelineParams now auto-validates any field ending in accessor_id or model_id; call_pipeline() maps these IDs to their loaded objects and passes them to the pipeline function as dicts of the respective objects.
parent 0c1677ec
 from typing import List, Union
 from fastapi import HTTPException
-from pydantic import BaseModel, Field, validator
+from pydantic import BaseModel, Field, root_validator
 from ..session import session, AccessorIdError


-class SingleModelPipelineParams(BaseModel):
+class PipelineParams(BaseModel):
     accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
     model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
     keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')

-    @validator('model_id')
-    def models_are_loaded(cls, model_id):
-        if model_id not in session.describe_loaded_models().keys():
-            raise HTTPException(status_code=409, detail=f'Model with ID {model_id} has not been loaded')
-        return model_id
+    @root_validator(pre=False)
+    def models_are_loaded(cls, dd):
+        for k, v in dd.items():
+            if k.endswith('model_id') and v not in session.describe_loaded_models().keys():
+                raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
+        return dd

-    @validator('accessor_id')
-    def accessors_are_loaded(cls, accessor_id):
-        try:
-            info = session.get_accessor_info(accessor_id)
-        except AccessorIdError as e:
-            raise HTTPException(status_code=409, detail=str(e))
-        if not info['loaded']:
-            raise HTTPException(status_code=409, detail=f'Accessor with ID {accessor_id} is not loaded')
-        return accessor_id
+    @root_validator(pre=False)
+    def accessors_are_loaded(cls, dd):
+        for k, v in dd.items():
+            if k.endswith('accessor_id'):
+                try:
+                    info = session.get_accessor_info(v)
+                except AccessorIdError as e:
+                    raise HTTPException(status_code=409, detail=str(e))
+                if not info['loaded']:
+                    raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
+        return dd

 class PipelineRecord(BaseModel):
...
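Because the new root validators iterate over every field name rather than being bound to fixed fields, a subclass that adds further ID parameters is validated with no extra code. A minimal sketch, assuming hypothetical field names (nuclei_accessor_id, object_model_id) that are not part of this commit:

from pydantic import Field

# Hypothetical subclass for illustration only; both added fields end in
# 'accessor_id' / 'model_id', so the root validators above check them automatically.
class TwoStageParams(PipelineParams):
    nuclei_accessor_id: str = Field(description='ID of a second accessor carrying the nuclei channel')
    object_model_id: str = Field(description='ID of a second, object-classification model')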
+from typing import Dict
+
 from fastapi import APIRouter
 from .util import call_pipeline
 from ..accessors import GenericImageDataAccessor
-from ..models import SemanticSegmentationModel
+from ..models import Model
-from .params import SingleModelPipelineParams, PipelineRecord
+from .params import PipelineParams, PipelineRecord
 from ..process import smooth
 from ..util import PipelineTrace
@@ -14,7 +16,9 @@ router = APIRouter(
     tags=['pipelines'],
 )

-class SegmentParams(SingleModelPipelineParams):
+class SegmentParams(PipelineParams):
+    accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
+    model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
     channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
@@ -25,11 +29,17 @@ def segment(p: SegmentParams) -> PipelineRecord:
     """
     return call_pipeline(segment_pipeline, p)

-def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
+def segment_pipeline(
+        accessors: Dict[str, GenericImageDataAccessor],
+        models: Dict[str, Model],
+        **k
+) -> PipelineTrace:
+    input_accessor = accessors.get('accessor')
+    model = models.get('model')
     d = PipelineTrace()
-    d['input'] = acc_in
+    d['input'] = input_accessor
     if (ch := k.get('channel')) is not None:
-        d['mono'] = acc_in.get_mono(ch)
+        d['mono'] = input_accessor.get_mono(ch)
     d['inference'] = model.label_pixel_class(d.last)
     if (sm := k.get('smooth')) is not None:
         d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
...
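For reference, a hedged sketch of driving the rewritten pipeline function directly with the new dict signature, assuming acc stands in for an already-constructed input accessor (the keys 'accessor' and 'model' follow from the field names accessor_id and model_id):

# Sketch only: 'acc' is a placeholder for a real GenericImageDataAccessor instance.
trace = segment_pipeline(
    accessors={'accessor': acc},
    models={'model': DummySemanticSegmentationModel()},
    channel=0,
)
labeled = trace.last  # final recorded step of the PipelineTrace (smoothed result if 'smooth' was passed)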
-from .params import SingleModelPipelineParams, PipelineRecord
+from .params import PipelineParams, PipelineRecord
 from ..session import session

-def call_pipeline(func, p: SingleModelPipelineParams):
-    acc_in = session.get_accessor(p.accessor_id, pop=True)
+def call_pipeline(func, p: PipelineParams):
+    accessors_in = {}
+    for k, v in p.dict().items():
+        if k.endswith('accessor_id'):
+            accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True)
+
+    models = {}
+    for k, v in p.dict().items():
+        if k.endswith('model_id'):
+            models[k.split('_id')[0]] = session.models[v]['object']

     steps = func(
-        acc_in,
-        session.models[p.model_id]['object'],
+        accessors_in,
+        models,
         **p.dict(),
     )
...
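The dict keys passed to the pipeline function are simply the parameter names with the trailing _id stripped, which is what lets segment_pipeline look up accessors.get('accessor') and models.get('model'). A small illustration of that mapping (nuclei_accessor_id is a hypothetical example, not from this commit):

for name in ('accessor_id', 'model_id', 'nuclei_accessor_id'):
    print(name, '->', name.split('_id')[0])
# accessor_id -> accessor
# model_id -> model
# nuclei_accessor_id -> nuclei_accessor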
@@ -127,12 +127,33 @@ class TestApiFromAutomatedClient(TestServerTestCase):
     def test_pipeline_errors_when_ids_not_found(self):
-        model_id = 'not_a_real_model'
-        resp = self._put(
-            f'pipelines/segment',
-            body={'model_id': model_id, 'accessor_id': 'fake'}
-        )
-        self.assertEqual(resp.status_code, 409, resp.content.decode())
+        self.copy_input_file_to_server()
+        model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id']
+        in_acc_id = self._put(
+            f'accessors/read_from_file',
+            query={
+                'filename': czifile['name'],
+            },
+        ).json()
+
+        # respond with 409 for invalid accessor_id
+        self.assertEqual(
+            self._put(
+                f'pipelines/segment',
+                body={'model_id': model_id, 'accessor_id': 'fake'}
+            ).status_code,
+            409
+        )
+
+        # respond with 409 for invalid model_id
+        self.assertEqual(
+            self._put(
+                f'pipelines/segment',
+                body={'model_id': 'fake', 'accessor_id': in_acc_id}
+            ).status_code,
+            409
+        )

     def test_i2i_dummy_inference_by_api(self):
         self.copy_input_file_to_server()
...
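The same behaviour can be exercised against a running server with any HTTP client. A hedged sketch, assuming the service listens on localhost:8000 and accepts the parameters as a JSON body (both are assumptions, not stated in this commit):

import requests

# Unknown IDs should be rejected by the PipelineParams root validators with HTTP 409
resp = requests.put(
    'http://localhost:8000/pipelines/segment',
    json={'accessor_id': 'fake', 'model_id': 'fake'},
)
print(resp.status_code)  # expected: 409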
@@ -3,12 +3,14 @@ import unittest

 from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
 from model_server.base.pipelines import segment
 import model_server.conf.testing as conf
 from tests.base.test_model import DummySemanticSegmentationModel

 czifile = conf.meta['image_files']['czifile']
 output_path = conf.meta['output_path']

 class TestSegmentationPipeline(unittest.TestCase):
     def setUp(self) -> None:
         self.model = DummySemanticSegmentationModel()
...