From 794d976cfe226271c0fdf2d894c1c5fa2dc63cd9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 9 Nov 2024 07:43:59 +0100
Subject: [PATCH] Tests pass using task queue

---
 model_server/base/api.py              | 14 +++++++++++---
 model_server/base/pipelines/shared.py |  2 +-
 model_server/base/session.py          | 10 ++++------
 tests/base/test_roiset_pipeline.py    |  2 +-
 4 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 5f6ea7f7..38705a2d 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -4,6 +4,7 @@ from typing import Dict, List, Union
 from fastapi import FastAPI, HTTPException
 from .accessors import generate_file_accessor
 from .models import BinaryThresholdSegmentationModel
+from .pipelines.shared import PipelineRecord
 from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError
 from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError
 
@@ -130,7 +131,7 @@ def delete_accessor(accessor_id: str):
     else:
         return _session_accessor(session.del_accessor, accessor_id)
 
-
+# TODO: optional lazy loading, so that batch task can be queued before file data is loaded
 @app.put('/accessors/read_from_file/{filename}')
 def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None):
     fp = session.paths['inbound_images'] / filename
@@ -215,9 +216,10 @@ class TaskInfo(BaseModel):
 
 # TODO: cover task API with tests, using dummy task resource endpoint
 
-# TODO: return something smarter than bool
+# TODO: test to cover reporting of exception in running tasks
+
 @app.put('/tasks/{task_id}/run')
-def run_task(task_id: str) -> bool:
+def run_task(task_id: str) -> PipelineRecord:
     return session.queue.run_task(task_id)
 
 @app.get('/tasks/{task_id}')
@@ -228,3 +230,9 @@ def get_task(task_id: str) -> TaskInfo:
 def list_tasks() -> Dict[str, TaskInfo]:
     res = session.queue.list_tasks()
     return res
+
+# TODO: implement
+@app.put('/tasks/run/on_files')
+def task_file_batch():
+    # new callable that parameterizes file name to acc_id, then passes this to to /tasks/*/run
+    pass
\ No newline at end of file
diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py
index 73b797a6..1e3e68d5 100644
--- a/model_server/base/pipelines/shared.py
+++ b/model_server/base/pipelines/shared.py
@@ -47,7 +47,7 @@ class PipelineRecord(BaseModel):
 class PipelineQueueRecord(BaseModel):
     task_id: str
 
-
+# TODO: variant of this for queued tasks where accessor ID is separately parameterized
 def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
     # match accessor IDs to loaded accessor objects
     accessors_in = {}
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 0021bf82..1cef23b1 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -76,16 +76,14 @@ class Queue(object):
         p = task['params']
         try:
             task['status'] = self.status_codes['in_progress']
-            task['result'] = f(p)
+            result = f(p)
             task['status'] = self.status_codes['finished']
-            return True
+            task['result'] = result
+            return result
         except Exception as e:
             task['status'] = self.status_codes['failed']
             task['error'] = e
-            return False
-
-
-
+            raise e
 
 class _Session(object):
     """
diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py
index fbfdc86f..5cb78e3e 100644
--- a/tests/base/test_roiset_pipeline.py
+++ b/tests/base/test_roiset_pipeline.py
@@ -287,4 +287,4 @@ class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
         self.assertTrue(res_run)
         self.assertEqual(self.assertGetSuccess(f'tasks/{task_id}')['status'], 'FINISHED')
 
-        return False
\ No newline at end of file
+        return res_run
\ No newline at end of file
-- 
GitLab