diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index c7163b80a5ace924d309e380679601b738052f96..459ed97526193bab7d040e899280f7bb54355dcd 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -58,20 +58,14 @@ class RoiSetObjectMapParams(PipelineParams): class RoiSetToObjectMapRecord(PipelineRecord): pass +# TODO: shorten endpoint name, maybe simplify return typespec @router.put('/roiset_to_obmap/infer') -def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: +def roiset_object_map(p: RoiSetObjectMapParams) -> Union[RoiSetToObjectMapRecord, PipelineQueueRecord]: """ Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. """ return call_pipeline(roiset_object_map_pipeline, p) -@router.put('/queue/roiset_to_obmap') -def queue_roiset_object_map(p: RoiSetObjectMapParams) -> PipelineQueueRecord: - task_id = session.queue.add_task( - lambda x: call_pipeline(roiset_object_map_pipeline, x), - p - ) - return {'task_id': task_id} def roiset_object_map_pipeline( accessors: Dict[str, GenericImageDataAccessor], diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py index 1e3e68d5389285e7ee059614ce84401f0dd707fe..1a1409108831e26e893a9db027469754c7177dc0 100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -4,6 +4,7 @@ from time import perf_counter from typing import List, Union from fastapi import HTTPException +from numba.scripts.generate_lower_listing import description from pydantic import BaseModel, Field, root_validator from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor @@ -12,6 +13,7 @@ from ..session import session, AccessorIdError class PipelineParams(BaseModel): + schedule: bool = Field(False, description='Schedule as a task instead of running immediately') keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session') api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True') @@ -47,8 +49,16 @@ class PipelineRecord(BaseModel): class PipelineQueueRecord(BaseModel): task_id: str -# TODO: variant of this for queued tasks where accessor ID is separately parameterized -def call_pipeline(func, p: PipelineParams) -> PipelineRecord: +def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]: + # instead of running right away, schedule pipeline as a task + if p.schedule: + p.schedule = False + task_id = session.queue.add_task( + lambda x: call_pipeline(func, x), + p + ) + return PipelineQueueRecord(task_id=task_id) + # match accessor IDs to loaded accessor objects accessors_in = {} diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py index 5cb78e3e41866714ea61744cc10dd538959c8243..f42527722f0745536e63ad685ef893e682827138 100644 --- a/tests/base/test_roiset_pipeline.py +++ b/tests/base/test_roiset_pipeline.py @@ -263,8 +263,9 @@ class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi): def _object_map_workflow(self, ob_classifer_id): res_queue = self.assertPutSuccess( - 'pipelines/queue/roiset_to_obmap', + 'pipelines/roiset_to_obmap/infer', body={ + 'schedule': True, 'accessor_id': self.test_load_input_accessor(), 'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(), 'object_classifier_model_id': ob_classifer_id,