From ff6e182dc7eaeb021ebf78cac63882ffd366ff4f Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 9 Nov 2024 06:47:26 +0100
Subject: [PATCH] Implemented queue, basic endpoints, and first pipeline
 example.  Did not yet test execution.

---
 model_server/base/api.py                    | 24 ++++-
 model_server/base/pipelines/roiset_obmap.py | 12 ++-
 model_server/base/pipelines/shared.py       |  4 +
 model_server/base/session.py                | 53 +++++++++++
 tests/base/test_api.py                      |  4 +-
 tests/base/test_roiset_pipeline.py          | 99 ++++++++++++++++++++-
 6 files changed, 188 insertions(+), 8 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 972c909e..9ef95fe6 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -1,5 +1,5 @@
 from pydantic import BaseModel, Field
-from typing import List, Union
+from typing import Dict, List, Union
 
 from fastapi import FastAPI, HTTPException
 from .accessors import generate_file_accessor
@@ -51,6 +51,7 @@ def show_session_status():
         'paths': session.get_paths(),
         'rois': session.list_rois(),
         'accessors': session.list_accessors(),
+        'tasks': session.queue.list_tasks(),
     }
 
 
@@ -202,4 +203,23 @@ def roiset_get_object_map(
     raise HTTPException(
         404,
         f'Did not find object map from classification model {model_id} in RoiSet {roiset_id}'
-    )
\ No newline at end of file
+    )
+
+class TaskInfo(BaseModel):
+    module: str
+    params: dict
+    func_str: str
+    status: str
+    error: Union[str, None]
+    result: Union[Dict, None]
+
+# TODO: cover is API tests, with dummy task resource endpoint
+
+@app.get('/tasks/{task_id')
+def get_task(task_id: str) -> TaskInfo:
+    return session.queue.get_task_info(task_id)
+
+@app.get('/tasks')
+def list_tasks() -> Dict[str, TaskInfo]:
+    res = session.queue.list_tasks()
+    return res
\ No newline at end of file
diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index 8fea5a83..c7163b80 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -8,9 +8,10 @@ from .segment_zproj import segment_zproj_pipeline
 from .shared import call_pipeline
 from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams
 
-from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord
+from ..pipelines.shared import PipelineQueueRecord, PipelineTrace, PipelineParams, PipelineRecord
 
 from ..models import Model, InstanceMaskSegmentationModel
+from ..session import session
 
 
 class RoiSetObjectMapParams(PipelineParams):
@@ -32,7 +33,7 @@ class RoiSetObjectMapParams(PipelineParams):
     )
     object_classifier_model_id: Union[str, None] = Field(
         None,
-        description='Object classifier used to classify segmented objectss'
+        description='Object classifier used to classify segmented objects'
     )
     patches_channel: int = Field(
         description='Channel of input image used in patches sent to object classifier'
@@ -64,6 +65,13 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
     """
     return call_pipeline(roiset_object_map_pipeline, p)
 
+@router.put('/queue/roiset_to_obmap')
+def queue_roiset_object_map(p: RoiSetObjectMapParams) -> PipelineQueueRecord:
+    task_id = session.queue.add_task(
+        lambda x: call_pipeline(roiset_object_map_pipeline, x),
+        p
+    )
+    return {'task_id': task_id}
 
 def roiset_object_map_pipeline(
         accessors: Dict[str, GenericImageDataAccessor],
diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py
index 3fbf1d58..73b797a6 100644
--- a/model_server/base/pipelines/shared.py
+++ b/model_server/base/pipelines/shared.py
@@ -44,6 +44,10 @@ class PipelineRecord(BaseModel):
     roiset_id: Union[str, None] = None
 
 
+class PipelineQueueRecord(BaseModel):
+    task_id: str
+
+
 def call_pipeline(func, p: PipelineParams) -> PipelineRecord:
     # match accessor IDs to loaded accessor objects
     accessors_in = {}
diff --git a/model_server/base/session.py b/model_server/base/session.py
index c2c596ed..202b33de 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -1,6 +1,7 @@
 from collections import OrderedDict
 import logging
 import os
+import uuid
 
 from pathlib import Path, PureWindowsPath
 from pydantic import BaseModel
@@ -32,6 +33,57 @@ class CsvTable(object):
         self.empty = False
         return True
 
+# TODO: is there a standard library alternative here?
+class Queue(object):
+
+    status_codes = {
+        'waiting': 'WAITING',
+        'in_progress': 'IN_PROGRESS',
+        'finished': 'FINISHED',
+        'failed': 'FAILED',
+    }
+
+    def __init__(self):
+        self._queue = OrderedDict()
+
+    def add_task(self, func: callable, params: dict) -> str:
+        task_id = str(uuid.uuid4())
+        name = func.__name__
+
+        self._queue[task_id] = {
+            'module': func.__module__,
+            'params': params,
+            'func_str': str(func),
+            'status': self.status_codes['waiting'],
+            'error': None,
+            'result': None,
+        }
+
+        return str(task_id)
+
+    def get_task_info(self, task_id: str) -> dict:
+        return self._queue[task_id]
+
+    def list_tasks(self) -> OrderedDict:
+        return self._queue
+
+    def run_task(self, task_id: str):
+        task = self._queue[task_id]
+        f = task['func']
+        p = task['params']
+        try:
+            task['status'] = self.status_codes['in_progress']
+            task['result'] = f(p)
+            task['status'] = self.status_codes['finished']
+            return True
+        except Exception as e:
+            task['status'] = self.status_codes['failed']
+            task['error'] = e
+            return False
+
+
+
+
 class _Session(object):
     """
     Singleton class for a server session that persists data between API calls
@@ -43,6 +95,7 @@ class _Session(object):
         self.models = {}  # model_id : model object
         self.paths = self.make_paths()
         self.accessors = OrderedDict()
+        self.queue = Queue()
         self.rois = OrderedDict()
 
         self.logfile = self.paths['logs'] / f'session.log'
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index a5df9e2f..8e2ba825 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -197,6 +197,4 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertPutSuccess(
             '/models/classify/threshold/load',
             body={}
-        )
-
-
+        )
\ No newline at end of file
diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py
index a32ec47b..0af1ebf2 100644
--- a/tests/base/test_roiset_pipeline.py
+++ b/tests/base/test_roiset_pipeline.py
@@ -1,3 +1,4 @@
+import json
 from pathlib import Path
 import unittest
 
@@ -106,7 +107,6 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertEqual(rois.count, 22)
         self.assertEqual(len(trace['ob_id'].unique()[0]), 2)
 
-
 class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
 
     input_data = data['multichannel_zstack_raw']
@@ -183,3 +183,100 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd
         res = self._object_map_workflow(None)
         acc_obmap = self.get_accessor(res['output_accessor_id'])
         self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
+
+class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
+
+    input_data = data['multichannel_zstack_raw']
+
+
+    def setUp(self) -> None:
+        self.where_out = output_path / 'roiset'
+        self.where_out.mkdir(parents=True, exist_ok=True)
+        return conf.TestServerBaseClass.setUp(self)
+
+    def test_load_input_accessor(self):
+        fname = self.copy_input_file_to_server()
+        return self.assertPutSuccess(f'accessors/read_from_file/{fname}')
+
+    def test_load_pixel_classifier(self):
+        mid = self.assertPutSuccess(
+            'models/seg/threshold/load/',
+            query={'tr': 0.2},
+        )['model_id']
+        self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
+        return mid
+
+    def test_load_object_classifier(self):
+        mid = self.assertPutSuccess(
+            'models/classify/threshold/load/',
+            body={'tr': 0}
+        )['model_id']
+        self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
+        return mid
+
+    def _object_map_workflow(self, ob_classifer_id):
+        res = self.assertPutSuccess(
+            'pipelines/roiset_to_obmap/infer',
+            body={
+                'accessor_id': self.test_load_input_accessor(),
+                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
+                'object_classifier_model_id': ob_classifer_id,
+                'segmentation': {'channel': 0},
+                'patches_channel': 1,
+                'roi_params': self._get_roi_params(),
+                'export_params': self._get_export_params(),
+            },
+        )
+
+        # check on automatically written RoiSet
+        roiset_id = res['roiset_id']
+        roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
+        self.assertGreater(roiset_info['count'], 0)
+        return res
+
+    def test_workflow_with_object_classifier(self):
+        obmod_id = self.test_load_object_classifier()
+        res = self._object_map_workflow(obmod_id)
+        acc_obmap = self.get_accessor(res['output_accessor_id'])
+        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
+
+        # get object map via RoiSet API
+        roiset_id = res['roiset_id']
+        obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True})
+        acc_obmap_roiset = self.get_accessor(obmap_id)
+        self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data))
+
+        # check serialize RoiSet
+        self.assertPutSuccess(f'rois/write/{roiset_id}')
+        self.assertFalse(
+            self.assertGetSuccess(f'rois/{roiset_id}')['loaded']
+        )
+
+
+    def test_workflow_without_object_classifier(self):
+        res = self._object_map_workflow(None)
+        acc_obmap = self.get_accessor(res['output_accessor_id'])
+        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))
+
+
+class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
+    def _object_map_workflow(self, ob_classifer_id):
+
+        res = self.assertPutSuccess(
+            'pipelines/queue/roiset_to_obmap',
+            body={
+                'accessor_id': self.test_load_input_accessor(),
+                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
+                'object_classifier_model_id': ob_classifer_id,
+                'segmentation': {'channel': 0},
+                'patches_channel': 1,
+                'roi_params': self._get_roi_params(),
+                'export_params': self._get_export_params(),
+            }
+        )
+
+        tasks = self.assertGetSuccess('tasks')
+
+        # check on enqueued task
+        task_id = res['task_id']
+        return res
\ No newline at end of file
-- 
GitLab