Skip to content
Snippets Groups Projects
Commit 40db6f45 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Reorganized pipeline utility method

parent 5b5d39b9
No related merge requests found
......@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, validator
from ..session import session, AccessorIdError
class SingleModelPipelineInputParams(BaseModel):
class SingleModelPipelineParams(BaseModel):
accessor_id: str = Field(description='ID(s) of previously loaded accessor(s) to use as pipeline input')
model_id: str = Field(description='ID(s) of previously loaded segmentation model(s)')
......@@ -25,7 +25,7 @@ class SingleModelPipelineInputParams(BaseModel):
return accessor_id
class PipelineOutputParams(BaseModel):
class PipelineRecord(BaseModel):
output_accessor_id: str
model_id: str
success: bool
......
from fastapi import APIRouter
from .util import call_pipeline
from ..accessors import GenericImageDataAccessor
from ..models import SemanticSegmentationModel
from .params import SingleModelPipelineInputParams, PipelineOutputParams
from .params import SingleModelPipelineParams, PipelineRecord
from ..process import smooth
from ..util import PipelineTrace
from ..session import session
from pydantic import Field
......@@ -14,39 +14,17 @@ router = APIRouter(
tags=['pipelines'],
)
class SegmentParams(SingleModelPipelineInputParams):
class SegmentParams(SingleModelPipelineParams):
channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
class SegmentRecord(PipelineOutputParams):
pass
def _call_pipeline(func, p: SingleModelPipelineInputParams):
acc_in = session.get_accessor(p.accessor_id, pop=True)
steps = func(
acc_in,
session.models[p.model_id]['object'],
**p.dict(),
)
session.log_info(f'Completed {func.__name__} on {p.accessor_id}.')
return SegmentRecord(
output_accessor_id=session.add_accessor(steps.last),
model_id=p.model_id,
success=True,
timer=steps.times
)
# TODO: handle registration of intermediate accessors
@router.put('/segment')
def segment(p: SegmentParams) -> SegmentRecord:
def segment(p: SegmentParams) -> PipelineRecord:
"""
Run a semantic segmentation model to compute a binary mask from an input image
"""
return _call_pipeline(segment_pipeline, p)
return call_pipeline(segment_pipeline, p)
def segment_pipeline(acc_in: GenericImageDataAccessor, model: SemanticSegmentationModel, **k) -> PipelineTrace:
d = PipelineTrace()
......
from .params import SingleModelPipelineParams, PipelineRecord
from ..session import session
def call_pipeline(func, p: SingleModelPipelineParams):
acc_in = session.get_accessor(p.accessor_id, pop=True)
steps = func(
acc_in,
session.models[p.model_id]['object'],
**p.dict(),
)
session.log_info(f'Completed {func.__name__} on {p.accessor_id}.')
return PipelineRecord(
output_accessor_id=session.add_accessor(steps.last),
model_id=p.model_id,
success=True,
timer=steps.times
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment