From ff6e182dc7eaeb021ebf78cac63882ffd366ff4f Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 9 Nov 2024 06:47:26 +0100 Subject: [PATCH] Implemented queue, basic endpoints, and first pipeline example. Did not yet test execution. --- model_server/base/api.py | 24 ++++- model_server/base/pipelines/roiset_obmap.py | 12 ++- model_server/base/pipelines/shared.py | 4 + model_server/base/session.py | 53 +++++++++++ tests/base/test_api.py | 4 +- tests/base/test_roiset_pipeline.py | 99 ++++++++++++++++++++- 6 files changed, 188 insertions(+), 8 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 972c909e..9ef95fe6 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import List, Union +from typing import Dict, List, Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor @@ -51,6 +51,7 @@ def show_session_status(): 'paths': session.get_paths(), 'rois': session.list_rois(), 'accessors': session.list_accessors(), + 'tasks': session.queue.list_tasks(), } @@ -202,4 +203,23 @@ def roiset_get_object_map( raise HTTPException( 404, f'Did not find object map from classification model {model_id} in RoiSet {roiset_id}' - ) \ No newline at end of file + ) + +class TaskInfo(BaseModel): + module: str + params: dict + func_str: str + status: str + error: Union[str, None] + result: Union[Dict, None] + +# TODO: cover is API tests, with dummy task resource endpoint + +@app.get('/tasks/{task_id') +def get_task(task_id: str) -> TaskInfo: + return session.queue.get_task_info(task_id) + +@app.get('/tasks') +def list_tasks() -> Dict[str, TaskInfo]: + res = session.queue.list_tasks() + return res \ No newline at end of file diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 8fea5a83..c7163b80 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -8,9 +8,10 @@ from .segment_zproj import segment_zproj_pipeline from .shared import call_pipeline from ..roiset import get_label_ids, RoiSet, RoiSetMetaParams, RoiSetExportParams -from ..pipelines.shared import PipelineTrace, PipelineParams, PipelineRecord +from ..pipelines.shared import PipelineQueueRecord, PipelineTrace, PipelineParams, PipelineRecord from ..models import Model, InstanceMaskSegmentationModel +from ..session import session class RoiSetObjectMapParams(PipelineParams): @@ -32,7 +33,7 @@ class RoiSetObjectMapParams(PipelineParams): ) object_classifier_model_id: Union[str, None] = Field( None, - description='Object classifier used to classify segmented objectss' + description='Object classifier used to classify segmented objects' ) patches_channel: int = Field( description='Channel of input image used in patches sent to object classifier' @@ -64,6 +65,13 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: """ return call_pipeline(roiset_object_map_pipeline, p) +@router.put('/queue/roiset_to_obmap') +def queue_roiset_object_map(p: RoiSetObjectMapParams) -> PipelineQueueRecord: + task_id = session.queue.add_task( + lambda x: call_pipeline(roiset_object_map_pipeline, x), + p + ) + return {'task_id': task_id} def roiset_object_map_pipeline( accessors: Dict[str, GenericImageDataAccessor], diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py index 3fbf1d58..73b797a6 100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -44,6 +44,10 @@ class PipelineRecord(BaseModel): roiset_id: Union[str, None] = None +class PipelineQueueRecord(BaseModel): + task_id: str + + def call_pipeline(func, p: PipelineParams) -> PipelineRecord: # match accessor IDs to loaded accessor objects accessors_in = {} diff --git a/model_server/base/session.py b/model_server/base/session.py index c2c596ed..202b33de 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -1,6 +1,7 @@ from collections import OrderedDict import logging import os +import uuid from pathlib import Path, PureWindowsPath from pydantic import BaseModel @@ -32,6 +33,57 @@ class CsvTable(object): self.empty = False return True +# TODO: is there a standard library alternative here? +class Queue(object): + + status_codes = { + 'waiting': 'WAITING', + 'in_progress': 'IN_PROGRESS', + 'finished': 'FINISHED', + 'failed': 'FAILED', + } + + def __init__(self): + self._queue = OrderedDict() + + def add_task(self, func: callable, params: dict) -> str: + task_id = str(uuid.uuid4()) + name = func.__name__ + + self._queue[task_id] = { + 'module': func.__module__, + 'params': params, + 'func_str': str(func), + 'status': self.status_codes['waiting'], + 'error': None, + 'result': None, + } + + return str(task_id) + + def get_task_info(self, task_id: str) -> dict: + return self._queue[task_id] + + def list_tasks(self) -> OrderedDict: + return self._queue + + def run_task(self, task_id: str): + task = self._queue[task_id] + f = task['func'] + p = task['params'] + try: + task['status'] = self.status_codes['in_progress'] + task['result'] = f(p) + task['status'] = self.status_codes['finished'] + return True + except Exception as e: + task['status'] = self.status_codes['failed'] + task['error'] = e + return False + + + + class _Session(object): """ Singleton class for a server session that persists data between API calls @@ -43,6 +95,7 @@ class _Session(object): self.models = {} # model_id : model object self.paths = self.make_paths() self.accessors = OrderedDict() + self.queue = Queue() self.rois = OrderedDict() self.logfile = self.paths['logs'] / f'session.log' diff --git a/tests/base/test_api.py b/tests/base/test_api.py index a5df9e2f..8e2ba825 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -197,6 +197,4 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertPutSuccess( '/models/classify/threshold/load', body={} - ) - - + ) \ No newline at end of file diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py index a32ec47b..0af1ebf2 100644 --- a/tests/base/test_roiset_pipeline.py +++ b/tests/base/test_roiset_pipeline.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import unittest @@ -106,7 +107,6 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(rois.count, 22) self.assertEqual(len(trace['ob_id'].unique()[0]), 2) - class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): input_data = data['multichannel_zstack_raw'] @@ -183,3 +183,100 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd res = self._object_map_workflow(None) acc_obmap = self.get_accessor(res['output_accessor_id']) self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1])) + +class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): + + input_data = data['multichannel_zstack_raw'] + + + def setUp(self) -> None: + self.where_out = output_path / 'roiset' + self.where_out.mkdir(parents=True, exist_ok=True) + return conf.TestServerBaseClass.setUp(self) + + def test_load_input_accessor(self): + fname = self.copy_input_file_to_server() + return self.assertPutSuccess(f'accessors/read_from_file/{fname}') + + def test_load_pixel_classifier(self): + mid = self.assertPutSuccess( + 'models/seg/threshold/load/', + query={'tr': 0.2}, + )['model_id'] + self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel')) + return mid + + def test_load_object_classifier(self): + mid = self.assertPutSuccess( + 'models/classify/threshold/load/', + body={'tr': 0} + )['model_id'] + self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation')) + return mid + + def _object_map_workflow(self, ob_classifer_id): + res = self.assertPutSuccess( + 'pipelines/roiset_to_obmap/infer', + body={ + 'accessor_id': self.test_load_input_accessor(), + 'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(), + 'object_classifier_model_id': ob_classifer_id, + 'segmentation': {'channel': 0}, + 'patches_channel': 1, + 'roi_params': self._get_roi_params(), + 'export_params': self._get_export_params(), + }, + ) + + # check on automatically written RoiSet + roiset_id = res['roiset_id'] + roiset_info = self.assertGetSuccess(f'rois/{roiset_id}') + self.assertGreater(roiset_info['count'], 0) + return res + + def test_workflow_with_object_classifier(self): + obmod_id = self.test_load_object_classifier() + res = self._object_map_workflow(obmod_id) + acc_obmap = self.get_accessor(res['output_accessor_id']) + self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1])) + + # get object map via RoiSet API + roiset_id = res['roiset_id'] + obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True}) + acc_obmap_roiset = self.get_accessor(obmap_id) + self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data)) + + # check serialize RoiSet + self.assertPutSuccess(f'rois/write/{roiset_id}') + self.assertFalse( + self.assertGetSuccess(f'rois/{roiset_id}')['loaded'] + ) + + + def test_workflow_without_object_classifier(self): + res = self._object_map_workflow(None) + acc_obmap = self.get_accessor(res['output_accessor_id']) + self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1])) + + +class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi): + def _object_map_workflow(self, ob_classifer_id): + + res = self.assertPutSuccess( + 'pipelines/queue/roiset_to_obmap', + body={ + 'accessor_id': self.test_load_input_accessor(), + 'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(), + 'object_classifier_model_id': ob_classifer_id, + 'segmentation': {'channel': 0}, + 'patches_channel': 1, + 'roi_params': self._get_roi_params(), + 'export_params': self._get_export_params(), + } + ) + + tasks = self.assertGetSuccess('tasks') + + # check on enqueued task + task_id = res['task_id'] + return res \ No newline at end of file -- GitLab