Skip to content
Snippets Groups Projects
segment.py 1.54 KiB
from typing import Dict, Optional

from pydantic import Field

from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord
from ..accessors import GenericImageDataAccessor
from ..models import Model, SemanticSegmentationModel
from ..process import smooth
from .router import router


class SegmentParams(PipelineParams):
    # Request parameters accepted by the '/segment' endpoint, in addition to whatever PipelineParams declares.
    # (Documented with comments rather than a class docstring so the generated model schema is left unchanged.)

    # Required: no default is supplied to Field().
    accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
    # Required: no default is supplied to Field().
    model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
    # Optional: the default is None, so the annotation must admit None as well as int; segment_pipeline only
    # reduces the input to a single channel when this value is not None.
    channel: Optional[int] = Field(None, description='Channel to use for segmentation; use all channels if empty.')


# Return type declared by the '/segment' endpoint; it adds no fields or behavior of its own beyond what it
# inherits from PipelineRecord.
class SegmentRecord(PipelineRecord):
    pass


# Registered on `router` for PUT requests to '/segment'.
# NOTE(review): `router` looks like a FastAPI router, in which case `p` is parsed from the request body and the
# docstring below is published as the endpoint description -- confirm in .router before rewording it.
@router.put('/segment')
def segment(p: SegmentParams) -> SegmentRecord:
    """
    Run a semantic segmentation model to compute a binary mask from an input image
    """
    # Delegates entirely to call_pipeline, passing segment_pipeline as the function to run and `p` as its
    # parameters; whatever call_pipeline returns is returned unchanged.
    # NOTE(review): call_pipeline presumably resolves the IDs in `p` to loaded accessors/models and forwards the
    # remaining fields as keyword arguments -- confirm in .shared.
    return call_pipeline(segment_pipeline, p)

def segment_pipeline(
        accessors: Dict[str, GenericImageDataAccessor],
        models: Dict[str, Model],
        **k
) -> PipelineTrace:
    """
    Run a semantic segmentation model on one input accessor, recording each step in a PipelineTrace.

    Steps, in order:
      1. 'mono' (only if k['channel'] is not None): the input reduced via get_mono(channel)
      2. 'inference': model.label_pixel_class() applied to the most recent step
      3. 'smooth' (only if k['smooth'] is not None): smooth(x, k['smooth']) applied to the inference result
      4. 'output': the most recent step

    :param accessors: mapping of input names to accessors; the entry under 'accessor' is the pipeline input
    :param models: mapping of model names to models; the entry under 'model' must be a SemanticSegmentationModel
    :param k: additional pipeline parameters; only 'channel' and 'smooth' are read here, and each is ignored
        when absent or None
    :return: the PipelineTrace holding the steps above
    :raises IncompatibleModelsError: if the 'model' entry is missing or is not a SemanticSegmentationModel
    """
    # NOTE(review): the trace is read back below via d['input'] and d.last, so PipelineTrace presumably stores
    # its constructor argument under 'input' and exposes the latest entry as .last -- confirm in .shared.
    d = PipelineTrace(accessors.get('accessor'))
    model = models.get('model')

    # A missing model yields None from .get(), which also fails this check.
    if not isinstance(model, SemanticSegmentationModel):
        raise IncompatibleModelsError('Expecting a semantic segmentation model')

    # The assignment expressions must be parenthesized: `:=` binds more loosely than `is not`, so without the
    # parentheses the name would be bound to the boolean result of the comparison instead of the actual value.
    if (ch := k.get('channel')) is not None:
        d['mono'] = d['input'].get_mono(ch)
    d['inference'] = model.label_pixel_class(d.last)
    if (sm := k.get('smooth')) is not None:
        d['smooth'] = d.last.apply(lambda x: smooth(x, sm))
    d['output'] = d.last
    return d