Skip to content
Snippets Groups Projects
Commit 82d4eaf7 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Option to schedule pipeline as a task

parent 794d976c
No related branches found
No related tags found
2 merge requests!102Merge staging as release,!72Pipeline task management
This commit is part of merge request !72. Comments created here will be created in the context of that merge request.
......@@ -58,20 +58,14 @@ class RoiSetObjectMapParams(PipelineParams):
class RoiSetToObjectMapRecord(PipelineRecord):
pass
# TODO: shorten endpoint name, maybe simplify return typespec
@router.put('/roiset_to_obmap/infer')
def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
def roiset_object_map(p: RoiSetObjectMapParams) -> Union[RoiSetToObjectMapRecord, PipelineQueueRecord]:
"""
Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification.
"""
return call_pipeline(roiset_object_map_pipeline, p)
@router.put('/queue/roiset_to_obmap')
def queue_roiset_object_map(p: RoiSetObjectMapParams) -> PipelineQueueRecord:
task_id = session.queue.add_task(
lambda x: call_pipeline(roiset_object_map_pipeline, x),
p
)
return {'task_id': task_id}
def roiset_object_map_pipeline(
accessors: Dict[str, GenericImageDataAccessor],
......
......@@ -4,6 +4,7 @@ from time import perf_counter
from typing import List, Union
from fastapi import HTTPException
from numba.scripts.generate_lower_listing import description
from pydantic import BaseModel, Field, root_validator
from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor
......@@ -12,6 +13,7 @@ from ..session import session, AccessorIdError
class PipelineParams(BaseModel):
schedule: bool = Field(False, description='Schedule as a task instead of running immediately')
keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')
......@@ -47,8 +49,16 @@ class PipelineRecord(BaseModel):
class PipelineQueueRecord(BaseModel):
task_id: str
# TODO: variant of this for queued tasks where accessor ID is separately parameterized
def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]:
# instead of running right away, schedule pipeline as a task
if p.schedule:
p.schedule = False
task_id = session.queue.add_task(
lambda x: call_pipeline(func, x),
p
)
return PipelineQueueRecord(task_id=task_id)
# match accessor IDs to loaded accessor objects
accessors_in = {}
......
......@@ -263,8 +263,9 @@ class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
def _object_map_workflow(self, ob_classifer_id):
res_queue = self.assertPutSuccess(
'pipelines/queue/roiset_to_obmap',
'pipelines/roiset_to_obmap/infer',
body={
'schedule': True,
'accessor_id': self.test_load_input_accessor(),
'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
'object_classifier_model_id': ob_classifer_id,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment