From 2a3f7babc3acb5e6b8721f71ca122ab43c6ba0a5 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 9 Nov 2024 07:00:05 +0100
Subject: [PATCH] Tasks runs, just need to hook up result with pipeline test

---
 model_server/base/api.py           | 11 ++++++++---
 model_server/base/session.py       |  5 ++++-
 tests/base/test_roiset_pipeline.py | 18 +++++++++++++-----
 3 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 9ef95fe6..5f6ea7f7 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -213,13 +213,18 @@ class TaskInfo(BaseModel):
     error: Union[str, None]
     result: Union[Dict, None]
 
-# TODO: cover is API tests, with dummy task resource endpoint
+# TODO: cover task API with tests, using dummy task resource endpoint
 
-@app.get('/tasks/{task_id')
+# TODO: return something smarter than bool
+@app.put('/tasks/{task_id}/run')
+def run_task(task_id: str) -> bool:
+    return session.queue.run_task(task_id)
+
+@app.get('/tasks/{task_id}')
 def get_task(task_id: str) -> TaskInfo:
     return session.queue.get_task_info(task_id)
 
 @app.get('/tasks')
 def list_tasks() -> Dict[str, TaskInfo]:
     res = session.queue.list_tasks()
-    return res
\ No newline at end of file
+    return res
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 202b33de..0021bf82 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -45,6 +45,7 @@ class Queue(object):
 
     def __init__(self):
         self._queue = OrderedDict()
+        self._handles = OrderedDict()
 
     def add_task(self, func: callable, params: dict) -> str:
         task_id = str(uuid.uuid4())
@@ -59,6 +60,8 @@ class Queue(object):
             'result': None,
         }
 
+        self._handles[task_id] = func
+
         return str(task_id)
 
     def get_task_info(self, task_id: str) -> dict:
@@ -69,7 +72,7 @@ class Queue(object):
 
     def run_task(self, task_id: str):
         task = self._queue[task_id]
-        f = task['func']
+        f = self._handles[task_id]
         p = task['params']
         try:
             task['status'] = self.status_codes['in_progress']
diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py
index 0af1ebf2..fbfdc86f 100644
--- a/tests/base/test_roiset_pipeline.py
+++ b/tests/base/test_roiset_pipeline.py
@@ -262,7 +262,7 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
 class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
     def _object_map_workflow(self, ob_classifer_id):
 
-        res = self.assertPutSuccess(
+        res_queue = self.assertPutSuccess(
             'pipelines/queue/roiset_to_obmap',
             body={
                 'accessor_id': self.test_load_input_accessor(),
@@ -275,8 +275,16 @@ class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
             }
         )
 
-        tasks = self.assertGetSuccess('tasks')
+        # check that task in enqueued
+        task_id = res_queue['task_id']
+        task_info = self.assertGetSuccess(f'tasks/{task_id}')
+        self.assertEqual(task_info['status'], 'WAITING')
+
+        # run the task
+        res_run = self.assertPutSuccess(
+            f'tasks/{task_id}/run'
+        )
+        self.assertTrue(res_run)
+        self.assertEqual(self.assertGetSuccess(f'tasks/{task_id}')['status'], 'FINISHED')
 
-        # check on enqueued task
-        task_id = res['task_id']
-        return res
\ No newline at end of file
+        return False
\ No newline at end of file
-- 
GitLab