from typing import Dict, Optional

from .shared import call_pipeline, IncompatibleModelsError, PipelineTrace, PipelineParams, PipelineRecord
from ..accessors import GenericImageDataAccessor
from ..models import Model, SemanticSegmentationModel
from ..process import smooth
from .router import router

from pydantic import Field


class SegmentParams(PipelineParams):
    """Request parameters for the /segment endpoint."""
    accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
    model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
    # Optional: when omitted (None), segment_pipeline skips the single-channel extraction step.
    channel: Optional[int] = Field(None, description='Channel to use for segmentation; use all channels if empty.')


class SegmentRecord(PipelineRecord):
    """Response record for the /segment endpoint; adds nothing to PipelineRecord."""
    pass


# TODO: simple threshold seg model (issue0045)


@router.put('/segment')
def segment(p: SegmentParams) -> SegmentRecord:
    """
    Run a semantic segmentation model to compute a binary mask from an input image
    """
    return call_pipeline(segment_pipeline, p)


def segment_pipeline(
        accessors: Dict[str, GenericImageDataAccessor],
        models: Dict[str, Model],
        **k
) -> PipelineTrace:
    """
    Segment an input image with a semantic segmentation model.

    :param accessors: input accessors; the image is taken from key 'accessor'
    :param models: loaded models; the segmentation model is taken from key 'model'
    :param k: optional keyword settings read by this pipeline:
        channel: if not None, only this channel of the input (via get_mono) is passed to the model
        smooth: if not None, passed as the second argument to smooth() applied to the inferred mask
    :return: PipelineTrace holding each step ('mono' and 'smooth' only when requested,
        'inference', and 'output'); 'output' is the last step computed
    :raises IncompatibleModelsError: if models['model'] is not a SemanticSegmentationModel
    """
    d = PipelineTrace(accessors.get('accessor'))
    model = models.get('model')

    if not isinstance(model, SemanticSegmentationModel):
        raise IncompatibleModelsError('Expecting a pixel classification model')

    # The walrus must be parenthesized: without parentheses it binds the boolean
    # result of `is not None`, so ch would be True (i.e. channel 1) instead of the requested channel.
    # NOTE(review): assumes PipelineTrace stores its constructor argument under 'input' — confirm.
    if (ch := k.get('channel')) is not None:
        d['mono'] = d['input'].get_mono(ch)

    # d.last is presumably the most recently recorded step (input or mono) — confirm in PipelineTrace.
    d['inference'] = model.label_pixel_class(d.last)

    # Same precedence fix as above; otherwise smooth() would receive True instead of the requested value.
    if (sm := k.get('smooth')) is not None:
        d['smooth'] = d.last.apply(lambda x: smooth(x, sm))

    d['output'] = d.last
    return d