diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py index 3e0f76e64dbbb86f8812e6de9d040d91f09a1abf..fc9c1be647a0bd0dfca9164819a9f7674b2671a1 100644 --- a/model_server/base/pipelines/segment.py +++ b/model_server/base/pipelines/segment.py @@ -1,11 +1,10 @@ -from pathlib import Path - from fastapi import APIRouter, HTTPException -from model_server.base.accessors import GenericImageDataAccessor, generate_file_accessor, write_accessor_data_to_file -from model_server.base.models import SemanticSegmentationModel -from model_server.base.util import PipelineTrace -from model_server.base.session import session +from ..accessors import GenericImageDataAccessor, generate_file_accessor, write_accessor_data_to_file +from ..models import SemanticSegmentationModel +from ..process import smooth +from ..util import PipelineTrace +from ..session import session from pydantic import BaseModel, validator @@ -14,7 +13,7 @@ router = APIRouter( tags=['pipelines'], ) -class FileCallParams(BaseModel): +class Params(BaseModel): model_id: str input_filename: str channel: int = None @@ -32,7 +31,7 @@ class FileCallParams(BaseModel): return v -class FileCallRecord(BaseModel): +class Record(BaseModel): model_id: str input_filepath: str output_filepath: str @@ -41,18 +40,18 @@ class FileCallRecord(BaseModel): @router.put('/segment') -def file_call(p: FileCallParams) -> FileCallRecord: +def file_call(p: Params) -> Record: inpath = session.paths['inbound_images'] / p.input_filename steps = classify_pixels( generate_file_accessor(inpath), session.models[p.model_id]['object'], - channel=p.channel, + **p, ) outpath = session.paths['outbound_images'] / (p.model_id + '_' + inpath.stem + '.tif') write_accessor_data_to_file(outpath, steps.last) session.log_info(f'Completed segmentation of {p.input_filename}') - return FileCallRecord( + return Record( model_id=p.model_id, input_filepath=inpath.__str__(), output_filepath=outpath.__str__(), @@ -60,8 +59,7 @@ def file_call(p: FileCallParams) -> FileCallRecord: timer=steps.times ) -# TODO: implement smoothing, move this away from model itself -def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **kwargs) -> PipelineTrace: +def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace: """ Run a semantic segmentation model to compute a binary mask from an input image @@ -72,8 +70,11 @@ def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentatio :return: record object """ d = PipelineTrace() - ch = kwargs.get('channel') - d['input'] = acc_in.get_mono(ch) + d['input'] = acc_in + if ch := k.get('channel'): + d['mono'] = acc_in.get_mono(ch) d['inference'] = model.label_pixel_class(d.last) + if sm := k.get('smooth'): + d['smooth'] = d.last.apply(lambda x: smooth(x, sm)) d['output'] = d.last return d \ No newline at end of file diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index d89b6f57cad18278f9508df8416f7efe46b21a77..ba7b036f0d8aefe66ba213d5c15046685fa9f03a 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -1,7 +1,7 @@ import unittest from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file -from model_server.base.pipelines.segment import classify_pixels +from model_server.base.pipelines.segment import classify_pixels, Params import model_server.conf.testing as conf from tests.base.test_model import DummySemanticSegmentationModel @@ -14,7 +14,7 @@ class TestGetSessionObject(unittest.TestCase): def test_single_session_instance(self): acc = generate_file_accessor(czifile['path']) - trace = classify_pixels(acc, self.model, channel=2) + trace = classify_pixels(acc, self.model, channel=2, smooth=3) outfp = output_path / 'classify_pixels.tif' write_accessor_data_to_file(outfp, trace.last)