-
Christopher Randolph Rhodes authored
segment.py 1.17 KiB
from typing import Optional

from fastapi import APIRouter
from pydantic import Field

from ..accessors import GenericImageDataAccessor
from ..models import SemanticSegmentationModel
from ..process import smooth
from ..util import PipelineTrace
from .params import SingleModelPipelineParams, PipelineRecord
from .util import call_pipeline
# Router shared by the pipeline endpoints in this module: every route registered on it
# is served under the /pipelines prefix and grouped under the 'pipelines' tag.
router = APIRouter(tags=['pipelines'], prefix='/pipelines')
class SegmentParams(SingleModelPipelineParams):
    """
    Request parameters for the /pipelines/segment endpoint.
    """
    # Optional[int] rather than int: the default is None ("use all channels"), and with a bare
    # `int` annotation pydantic would reject a client that explicitly sends "channel": null.
    channel: Optional[int] = Field(None, description='Channel to use for segmentation; use all channels if empty.')
@router.put('/segment')
def segment(p: SegmentParams) -> PipelineRecord:
    """
    Compute a binary mask for an input image using a semantic segmentation model
    """
    # call_pipeline comes from .util; presumably it resolves the input image and model
    # referenced by `p` before invoking segment_pipeline — confirm against .util.
    record = call_pipeline(segment_pipeline, p)
    return record
def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
    """
    Run a semantic segmentation model on an image, optionally restricting it to one channel
    beforehand and smoothing its output afterwards.

    :param acc_in: input image accessor
    :param model: segmentation model whose label_pixel_class() is applied to the current image
    :param k: optional keyword arguments:
        channel: if present and not None, the channel index passed to acc_in.get_mono()
            before inference; otherwise the full input image is used
        smooth: if present and not None, passed as the second argument of process.smooth()
            applied to the inference result
    :return: PipelineTrace with entries 'input', optionally 'mono', 'inference',
        optionally 'smooth', and 'output'
    """
    d = PipelineTrace()
    d['input'] = acc_in

    # The assignment expressions must be parenthesized: `ch := x is not None` parses as
    # `ch := (x is not None)`, which would bind the boolean instead of the channel index.
    if (ch := k.get('channel')) is not None:
        d['mono'] = acc_in.get_mono(ch)

    # d.last is assumed to be the most recently stored entry of the trace — confirm in ..util.
    d['inference'] = model.label_pixel_class(d.last)

    if (sm := k.get('smooth')) is not None:
        d['smooth'] = d.last.apply(lambda x: smooth(x, sm))

    d['output'] = d.last
    return d