-
Christopher Randolph Rhodes authoredChristopher Randolph Rhodes authored
segment.py 1.54 KiB
from typing import Dict, Optional

from pydantic import Field

from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord
from ..accessors import GenericImageDataAccessor
from ..models import Model, SemanticSegmentationModel
from ..process import smooth
from .router import router
class SegmentParams(PipelineParams):
    """Request body for the ``/segment`` endpoint."""

    accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
    model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
    # The default is None, so the annotation must admit None: otherwise an explicit
    # ``null`` in the request fails validation and the published schema claims a plain int.
    channel: Optional[int] = Field(None, description='Channel to use for segmentation; use all channels if empty.')
    # NOTE(review): segment_pipeline also reads an optional 'smooth' value; confirm it is
    # declared on PipelineParams, otherwise API callers have no way to supply it.
class SegmentRecord(PipelineRecord):
    # Response model for the /segment endpoint; it adds no fields beyond PipelineRecord.
    # (A comment rather than a docstring, so the generated schema description is unchanged.)
    ...
@router.put('/segment')
def segment(p: SegmentParams) -> SegmentRecord:
    """
    Run a semantic segmentation model to compute a binary mask from an input image
    """
    # Hand the request parameters and this module's pipeline function to the shared runner.
    record = call_pipeline(segment_pipeline, p)
    return record
def segment_pipeline(
        accessors: Dict[str, GenericImageDataAccessor],
        models: Dict[str, Model],
        **k
) -> PipelineTrace:
    """
    Segment an input image with a semantic segmentation model.

    :param accessors: accessors keyed by role; the input image is taken from key 'accessor'
    :param models: models keyed by role; the segmentation model is taken from key 'model'
    :param k: optional keyword parameters:
        channel: if not None, only this channel (via get_mono) is passed to the model;
            otherwise the whole input is used
        smooth: if not None, passed as the second argument to process.smooth, which is
            applied to the inference result
    :return: PipelineTrace with 'inference' and 'output' entries, plus 'mono' and/or
        'smooth' when those options are given
    :raises IncompatibleModelsError: if models has no 'model' entry or it is not a
        SemanticSegmentationModel
    """
    d = PipelineTrace(accessors.get('accessor'))

    model = models.get('model')
    if not isinstance(model, SemanticSegmentationModel):
        raise IncompatibleModelsError('Expecting a semantic segmentation model')

    # Parentheses are required: without them ``:=`` captures the boolean result of
    # ``is not None`` rather than the channel index itself.
    # NOTE(review): 'input' presumably holds the accessor PipelineTrace was built from.
    if (ch := k.get('channel')) is not None:
        d['mono'] = d['input'].get_mono(ch)

    # d.last is presumably the most recent entry: 'mono' if extracted above, else the input.
    d['inference'] = model.label_pixel_class(d.last)

    # Same precedence fix: bind the smoothing value, not the comparison result.
    if (sm := k.get('smooth')) is not None:
        d['smooth'] = d.last.apply(lambda x: smooth(x, sm))

    d['output'] = d.last
    return d