diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py index fc9c1be647a0bd0dfca9164819a9f7674b2671a1..192a9c9c4bf6184670b38cf5c8efea4df28282a8 100644 --- a/model_server/base/pipelines/segment.py +++ b/model_server/base/pipelines/segment.py @@ -6,17 +6,17 @@ from ..process import smooth from ..util import PipelineTrace from ..session import session -from pydantic import BaseModel, validator +from pydantic import BaseModel, Field, validator router = APIRouter( prefix='/pipelines', tags=['pipelines'], ) -class Params(BaseModel): +class SegmentParams(BaseModel): model_id: str input_filename: str - channel: int = None + channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.') @validator('model_id') def model_is_loaded(cls, v): @@ -31,18 +31,22 @@ class Params(BaseModel): return v -class Record(BaseModel): +class SegmentRecord(BaseModel): model_id: str input_filepath: str output_filepath: str success: bool - timer: PipelineTrace + timer: dict @router.put('/segment') -def file_call(p: Params) -> Record: +def segment(p: SegmentParams) -> SegmentRecord: + """ + Run a semantic segmentation model to compute a binary mask from an input image + """ + inpath = session.paths['inbound_images'] / p.input_filename - steps = classify_pixels( + steps = pipeline( generate_file_accessor(inpath), session.models[p.model_id]['object'], **p, @@ -51,7 +55,7 @@ def file_call(p: Params) -> Record: write_accessor_data_to_file(outpath, steps.last) session.log_info(f'Completed segmentation of {p.input_filename}') - return Record( + return SegmentRecord( model_id=p.model_id, input_filepath=inpath.__str__(), output_filepath=outpath.__str__(), @@ -59,16 +63,7 @@ def file_call(p: Params) -> Record: timer=steps.times ) -def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace: - """ - Run a semantic segmentation model to compute a binary mask from an input image - - :param fpi: Path object that references input image file - :param model: semantic segmentation model instance - :param where_output: Path object that references output image directory - :param kwargs: variable-length keyword arguments - :return: record object - """ +def pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace: d = PipelineTrace() d['input'] = acc_in if ch := k.get('channel'): diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index ba7b036f0d8aefe66ba213d5c15046685fa9f03a..774b789f0ea25ec2ca3a7a775cd7ed9a3935878e 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -1,7 +1,7 @@ import unittest from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.base.pipelines.segment import classify_pixels, Params +from model_server.base.pipelines import segment import model_server.conf.testing as conf from tests.base.test_model import DummySemanticSegmentationModel @@ -14,7 +14,7 @@ class TestGetSessionObject(unittest.TestCase): def test_single_session_instance(self): acc = generate_file_accessor(czifile['path']) - trace = classify_pixels(acc, self.model, channel=2, smooth=3) + trace = segment.pipeline(acc, self.model, channel=2, smooth=3) outfp = output_path / 'classify_pixels.tif' write_accessor_data_to_file(outfp, trace.last)