From 794d976cfe226271c0fdf2d894c1c5fa2dc63cd9 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 9 Nov 2024 07:43:59 +0100 Subject: [PATCH] Tests pass using task queue --- model_server/base/api.py | 14 +++++++++++--- model_server/base/pipelines/shared.py | 2 +- model_server/base/session.py | 10 ++++------ tests/base/test_roiset_pipeline.py | 2 +- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 5f6ea7f7..38705a2d 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -4,6 +4,7 @@ from typing import Dict, List, Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor from .models import BinaryThresholdSegmentationModel +from .pipelines.shared import PipelineRecord from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError @@ -130,7 +131,7 @@ def delete_accessor(accessor_id: str): else: return _session_accessor(session.del_accessor, accessor_id) - +# TODO: optional lazy loading, so that batch task can be queued before file data is loaded @app.put('/accessors/read_from_file/{filename}') def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None): fp = session.paths['inbound_images'] / filename @@ -215,9 +216,10 @@ class TaskInfo(BaseModel): # TODO: cover task API with tests, using dummy task resource endpoint -# TODO: return something smarter than bool +# TODO: test to cover reporting of exception in running tasks + @app.put('/tasks/{task_id}/run') -def run_task(task_id: str) -> bool: +def run_task(task_id: str) -> PipelineRecord: return session.queue.run_task(task_id) @app.get('/tasks/{task_id}') @@ -228,3 +230,9 @@ def get_task(task_id: str) -> TaskInfo: def list_tasks() -> Dict[str, TaskInfo]: res = session.queue.list_tasks() return res + +# TODO: implement +@app.put('/tasks/run/on_files') +def task_file_batch(): + # new callable that parameterizes file name to acc_id, then passes this to to /tasks/*/run + pass \ No newline at end of file diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py index 73b797a6..1e3e68d5 100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -47,7 +47,7 @@ class PipelineRecord(BaseModel): class PipelineQueueRecord(BaseModel): task_id: str - +# TODO: variant of this for queued tasks where accessor ID is separately parameterized def call_pipeline(func, p: PipelineParams) -> PipelineRecord: # match accessor IDs to loaded accessor objects accessors_in = {} diff --git a/model_server/base/session.py b/model_server/base/session.py index 0021bf82..1cef23b1 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -76,16 +76,14 @@ class Queue(object): p = task['params'] try: task['status'] = self.status_codes['in_progress'] - task['result'] = f(p) + result = f(p) task['status'] = self.status_codes['finished'] - return True + task['result'] = result + return result except Exception as e: task['status'] = self.status_codes['failed'] task['error'] = e - return False - - - + raise e class _Session(object): """ diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py index fbfdc86f..5cb78e3e 100644 --- a/tests/base/test_roiset_pipeline.py +++ b/tests/base/test_roiset_pipeline.py @@ -287,4 +287,4 @@ class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi): self.assertTrue(res_run) self.assertEqual(self.assertGetSuccess(f'tasks/{task_id}')['status'], 'FINISHED') - return False \ No newline at end of file + return res_run \ No newline at end of file -- GitLab