Skip to content
Snippets Groups Projects
Commit a197934e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Segmentation pipeline works, module is annotated

parent 2f644dfc
No related branches found
No related tags found
No related merge requests found
......@@ -6,17 +6,17 @@ from ..process import smooth
from ..util import PipelineTrace
from ..session import session
from pydantic import BaseModel, validator
from pydantic import BaseModel, Field, validator
router = APIRouter(
prefix='/pipelines',
tags=['pipelines'],
)
class Params(BaseModel):
class SegmentParams(BaseModel):
model_id: str
input_filename: str
channel: int = None
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
@validator('model_id')
def model_is_loaded(cls, v):
......@@ -31,18 +31,22 @@ class Params(BaseModel):
return v
class Record(BaseModel):
class SegmentRecord(BaseModel):
model_id: str
input_filepath: str
output_filepath: str
success: bool
timer: PipelineTrace
timer: dict
@router.put('/segment')
def file_call(p: Params) -> Record:
def segment(p: SegmentParams) -> SegmentRecord:
"""
Run a semantic segmentation model to compute a binary mask from an input image
"""
inpath = session.paths['inbound_images'] / p.input_filename
steps = classify_pixels(
steps = pipeline(
generate_file_accessor(inpath),
session.models[p.model_id]['object'],
**p,
......@@ -51,7 +55,7 @@ def file_call(p: Params) -> Record:
write_accessor_data_to_file(outpath, steps.last)
session.log_info(f'Completed segmentation of {p.input_filename}')
return Record(
return SegmentRecord(
model_id=p.model_id,
input_filepath=inpath.__str__(),
output_filepath=outpath.__str__(),
......@@ -59,16 +63,7 @@ def file_call(p: Params) -> Record:
timer=steps.times
)
def classify_pixels(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
"""
Run a semantic segmentation model to compute a binary mask from an input image
:param fpi: Path object that references input image file
:param model: semantic segmentation model instance
:param where_output: Path object that references output image directory
:param kwargs: variable-length keyword arguments
:return: record object
"""
def pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
d = PipelineTrace()
d['input'] = acc_in
if ch := k.get('channel'):
......
import unittest
from model_server.base.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.base.pipelines.segment import classify_pixels, Params
from model_server.base.pipelines import segment
import model_server.conf.testing as conf
from tests.base.test_model import DummySemanticSegmentationModel
......@@ -14,7 +14,7 @@ class TestGetSessionObject(unittest.TestCase):
def test_single_session_instance(self):
acc = generate_file_accessor(czifile['path'])
trace = classify_pixels(acc, self.model, channel=2, smooth=3)
trace = segment.pipeline(acc, self.model, channel=2, smooth=3)
outfp = output_path / 'classify_pixels.tif'
write_accessor_data_to_file(outfp, trace.last)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment