Skip to content
Snippets Groups Projects
Commit 6735a2f7 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Parameters auto-validates anything ending in accessor_id and model_id;...

Parameters auto-validates anything ending in accessor_id and model_id; call_pipeline() maps these to objects and sends to pipeline function as dicts of respective objects
parent 0c1677ec
No related branches found
No related tags found
No related merge requests found
from typing import List, Union
from fastapi import HTTPException
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator
from ..session import session, AccessorIdError
class SingleModelPipelineParams(BaseModel):
class PipelineParams(BaseModel):
accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
@validator('model_id')
def models_are_loaded(cls, model_id):
if model_id not in session.describe_loaded_models().keys():
raise HTTPException(status_code=409, detail=f'Model with ID {model_id} has not been loaded')
return model_id
@validator('accessor_id')
def accessors_are_loaded(cls, accessor_id):
try:
info = session.get_accessor_info(accessor_id)
except AccessorIdError as e:
raise HTTPException(status_code=409, detail=str(e))
if not info['loaded']:
raise HTTPException(status_code=409, detail=f'Accessor with ID {accessor_id} is not loaded')
return accessor_id
@root_validator(pre=False)
def models_are_loaded(cls, dd):
for k, v in dd.items():
if k.endswith('model_id') and v not in session.describe_loaded_models().keys():
raise HTTPException(status_code=409, detail=f'Model with {k} = {v} has not been loaded')
return dd
@root_validator(pre=False)
def accessors_are_loaded(cls, dd):
for k, v in dd.items():
if k.endswith('accessor_id'):
try:
info = session.get_accessor_info(v)
except AccessorIdError as e:
raise HTTPException(status_code=409, detail=str(e))
if not info['loaded']:
raise HTTPException(status_code=409, detail=f'Accessor with {k} = {v} has not been loaded')
return dd
class PipelineRecord(BaseModel):
......
from typing import Dict
from fastapi import APIRouter
from .util import call_pipeline
from ..accessors import GenericImageDataAccessor
from ..models import SemanticSegmentationModel
from .params import SingleModelPipelineParams, PipelineRecord
from ..models import Model
from .params import PipelineParams, PipelineRecord
from ..process import smooth
from ..util import PipelineTrace
......@@ -14,7 +16,9 @@ router = APIRouter(
tags=['pipelines'],
)
class SegmentParams(SingleModelPipelineParams):
class SegmentParams(PipelineParams):
accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
......@@ -25,11 +29,17 @@ def segment(p: SegmentParams) -> PipelineRecord:
"""
return call_pipeline(segment_pipeline, p)
def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
def segment_pipeline(
accessors: Dict[str, GenericImageDataAccessor],
models: Dict[str, Model],
**k
) -> PipelineTrace:
input_accessor = accessors.get('accessor')
model = models.get('model')
d = PipelineTrace()
d['input'] = acc_in
d['input'] = input_accessor
if ch := k.get('channel') is not None:
d['mono'] = acc_in.get_mono(ch)
d['mono'] = input_accessor.get_mono(ch)
d['inference'] = model.label_pixel_class(d.last)
if sm := k.get('smooth') is not None:
d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
......
from .params import SingleModelPipelineParams, PipelineRecord
from .params import PipelineParams, PipelineRecord
from ..session import session
def call_pipeline(func, p: SingleModelPipelineParams):
acc_in = session.get_accessor(p.accessor_id, pop=True)
def call_pipeline(func, p: PipelineParams):
accessors_in = {}
for k, v in p.dict().items():
if k.endswith('accessor_id'):
accessors_in[k.split('_id')[0]] = session.get_accessor(v, pop=True)
models = {}
for k, v in p.dict().items():
if k.endswith('model_id'):
models[k.split('_id')[0]] = session.models[v]['object']
steps = func(
acc_in,
session.models[p.model_id]['object'],
accessors_in,
models,
**p.dict(),
)
......
......@@ -127,12 +127,33 @@ class TestApiFromAutomatedClient(TestServerTestCase):
def test_pipeline_errors_when_ids_not_found(self):
model_id = 'not_a_real_model'
resp = self._put(
f'pipelines/segment',
body={'model_id': model_id, 'accessor_id': 'fake'}
self.copy_input_file_to_server()
model_id = self._put(f'testing/models/dummy_semantic/load').json()['model_id']
in_acc_id = self._put(
f'accessors/read_from_file',
query={
'filename': czifile['name'],
},
).json()
# respond with 409 for invalid accessor_id
self.assertEqual(
self._put(
f'pipelines/segment',
body={'model_id': model_id, 'accessor_id': 'fake'}
).status_code,
409
)
self.assertEqual(resp.status_code, 409, resp.content.decode())
# respond with 409 for invalid model_id
self.assertEqual(
self._put(
f'pipelines/segment',
body={'model_id': 'fake', 'accessor_id': in_acc_id}
).status_code,
409
)
def test_i2i_dummy_inference_by_api(self):
self.copy_input_file_to_server()
......
......@@ -3,12 +3,14 @@ import unittest
from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.base.pipelines import segment
import model_server.conf.testing as conf
from tests.base.test_model import DummySemanticSegmentationModel
czifile = conf.meta['image_files']['czifile']
output_path = conf.meta['output_path']
class TestSegmentationPipeline(unittest.TestCase):
def setUp(self) -> None:
self.model = DummySemanticSegmentationModel()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment