Skip to content
Snippets Groups Projects
segment.py 1.17 KiB
from typing import Optional

from fastapi import APIRouter
from pydantic import Field

from .util import call_pipeline
from ..accessors import GenericImageDataAccessor
from ..models import SemanticSegmentationModel
from .params import SingleModelPipelineParams, PipelineRecord
from ..process import smooth
from ..util import PipelineTrace

# All routes in this module are mounted under /pipelines and grouped under the
# 'pipelines' tag in the generated API docs.
router = APIRouter(prefix='/pipelines', tags=['pipelines'])

class SegmentParams(SingleModelPipelineParams):
    # Request parameters for the segmentation pipeline.
    # ``channel`` is optional: None (the default, or an explicit null) means no
    # single channel is extracted before inference. Annotated Optional[int] so
    # that the declared type matches the None default and clients may send null.
    channel: Optional[int] = Field(None, description='Channel to use for segmentation; use all channels if empty.')


@router.put('/segment')
def segment(p: SegmentParams) -> PipelineRecord:
    """
    Run a semantic segmentation model to compute a binary mask from an input image
    """
    # The docstring above is exposed by FastAPI as the endpoint description.
    # This handler only hands the validated params to the shared pipeline
    # runner together with the step function that does the actual work.
    record = call_pipeline(segment_pipeline, p)
    return record

def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
    """
    Segment an image with a semantic segmentation model, recording each stage.

    :param acc_in: input image accessor
    :param model: model whose ``label_pixel_class`` is run on the current stage
    :param k: optional keyword parameters:
        ``channel`` -- if not None, ``acc_in.get_mono(channel)`` is taken before
        inference;
        ``smooth`` -- if not None, passed as the second argument to ``smooth``,
        which is applied (via ``.apply``) to the inference result
    :return: PipelineTrace with entries 'input', optionally 'mono', 'inference',
        optionally 'smooth', and 'output' (set to ``d.last``)
    """
    d = PipelineTrace()
    d['input'] = acc_in

    # The walrus must be parenthesized: ``ch := x is not None`` parses as
    # ``ch := (x is not None)`` and would bind True/False instead of the channel.
    if (ch := k.get('channel')) is not None:
        d['mono'] = acc_in.get_mono(ch)

    # NOTE(review): d.last is presumably the most recently stored entry (the
    # mono image if one was extracted, otherwise the input) -- confirm against
    # PipelineTrace.
    d['inference'] = model.label_pixel_class(d.last)

    # Same precedence fix as above: bind the smoothing value, not a boolean.
    if (sm := k.get('smooth')) is not None:
        d['smooth'] = d.last.apply(lambda x: smooth(x, sm))

    d['output'] = d.last
    return d