diff --git a/model_server/base/api.py b/model_server/base/api.py index 972c909e22ed28c6bdd115d96c10be8ffa3bf9e3..205014b6607c1b9dee743548fbe9a6334872971d 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,11 +1,12 @@ from pydantic import BaseModel, Field -from typing import List, Union +from typing import Dict, List, Union from fastapi import FastAPI, HTTPException from .accessors import generate_file_accessor from .models import BinaryThresholdSegmentationModel +from .pipelines.shared import PipelineRecord from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError -from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError +from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, RunTaskError, WriteAccessorError app = FastAPI(debug=True) @@ -51,6 +52,7 @@ def show_session_status(): 'paths': session.get_paths(), 'rois': session.list_rois(), 'accessors': session.list_accessors(), + 'tasks': session.tasks.list_tasks(), } @@ -129,7 +131,6 @@ def delete_accessor(accessor_id: str): else: return _session_accessor(session.del_accessor, accessor_id) - @app.put('/accessors/read_from_file/{filename}') def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None): fp = session.paths['inbound_images'] / filename @@ -140,9 +141,9 @@ def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None) @app.put('/accessors/write_to_file/{accessor_id}') -def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) -> str: +def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None, pop: bool = True) -> str: try: - return session.write_accessor(accessor_id, filename) + return session.write_accessor(accessor_id, filename, pop=pop) except AccessorIdError as e: raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') except WriteAccessorError as e: @@ -202,4 
class TaskInfo(BaseModel):
    """
    Serializable description of one task held in the session's task collection.
    """
    module: str  # module of the callable registered for the task
    params: dict  # argument the callable is invoked with when the task runs
    func_str: str  # str() of the registered callable
    status: str  # WAITING, IN_PROGRESS, FINISHED, or FAILED
    error: Union[str, None]  # message of the exception that failed the task, otherwise None
    result: Union[Dict, None]  # value returned by the task once it has finished, otherwise None


@app.put('/tasks/{task_id}/run')
def run_task(task_id: str) -> PipelineRecord:
    """
    Run a previously registered task and return what it produced.
    :param task_id: ID returned when the task was registered
    :return: result of the task
    :raises HTTPException: 409 if the task raised while running, 404 if no task has this ID
    """
    try:
        return session.tasks.run_task(task_id)
    except RunTaskError as e:
        raise HTTPException(409, str(e)) from e
    except KeyError as e:
        # the task collection looks tasks up in a dict, so an unknown ID surfaces as KeyError;
        # errors raised inside the task itself are already wrapped in RunTaskError above
        raise HTTPException(404, f'Did not find task with ID {task_id}') from e


@app.get('/tasks/{task_id}')
def get_task(task_id: str) -> TaskInfo:
    """
    Report the stored information of a single task.
    :param task_id: ID returned when the task was registered
    :return: module, parameters, status, error and result of the task
    :raises HTTPException: 404 if no task has this ID
    """
    try:
        return session.tasks.get_task_info(task_id)
    except KeyError as e:
        raise HTTPException(404, f'Did not find task with ID {task_id}') from e


@app.get('/tasks')
def list_tasks() -> Dict[str, TaskInfo]:
    """
    Report the stored information of every task in the session, keyed by task ID.
    """
    return session.tasks.list_tasks()
RoiSetObjectMapParams(PipelineParams): class RoiSetToObjectMapRecord(PipelineRecord): pass + @router.put('/roiset_to_obmap/infer') -def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: +def roiset_object_map(p: RoiSetObjectMapParams) -> Union[RoiSetToObjectMapRecord, PipelineQueueRecord]: """ Compute a RoiSet from 2d segmentation, apply to z-stack, and optionally apply object classification. """ diff --git a/model_server/base/pipelines/shared.py b/model_server/base/pipelines/shared.py index 3fbf1d5838cd05017240d11678876b76eca3ff03..be12c1f20243b67c269075856f435978923d4375 100644 --- a/model_server/base/pipelines/shared.py +++ b/model_server/base/pipelines/shared.py @@ -4,6 +4,7 @@ from time import perf_counter from typing import List, Union from fastapi import HTTPException +from numba.scripts.generate_lower_listing import description from pydantic import BaseModel, Field, root_validator from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor @@ -12,6 +13,7 @@ from ..session import session, AccessorIdError class PipelineParams(BaseModel): + schedule: bool = Field(False, description='Schedule as a task instead of running immediately') keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session') api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True') @@ -44,7 +46,19 @@ class PipelineRecord(BaseModel): roiset_id: Union[str, None] = None -def call_pipeline(func, p: PipelineParams) -> PipelineRecord: +class PipelineQueueRecord(BaseModel): + task_id: str + +def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]: + # instead of running right away, schedule pipeline as a task + if p.schedule: + p.schedule = False + task_id = session.tasks.add_task( + lambda x: call_pipeline(func, x), + p + ) + return PipelineQueueRecord(task_id=task_id) + # match accessor IDs to loaded accessor objects accessors_in = {} diff --git 
class TaskCollection(object):
    """
    In-memory registry of deferred tasks.

    Each task pairs a callable with the single argument it will be called with. Tasks are kept
    in insertion order and are only executed when run_task() is called explicitly.
    """

    status_codes = {
        'waiting': 'WAITING',
        'in_progress': 'IN_PROGRESS',
        'finished': 'FINISHED',
        'failed': 'FAILED',
    }

    def __init__(self):
        self._tasks = OrderedDict()  # task_id : info dict (module, params, func_str, status, error, result)
        self._handles = OrderedDict()  # task_id : callable to invoke

    def add_task(self, func: callable, params: dict) -> str:
        """
        Register a callable to be run later and mark it as waiting.
        :param func: callable that accepts params as its only argument
        :param params: argument passed to func when the task is run
        :return: newly generated ID of the task
        """
        task_id = str(uuid.uuid4())

        self._tasks[task_id] = {
            'module': func.__module__,
            'params': params,
            'func_str': str(func),
            'status': self.status_codes['waiting'],
            'error': None,
            'result': None,
        }

        self._handles[task_id] = func
        logger.info(f'Added task {task_id}: {str(func)}')

        return task_id

    def get_task_info(self, task_id: str) -> dict:
        """
        Return the info dict stored for one task.
        :param task_id: ID of the task
        :raises KeyError: if no task has this ID
        """
        return self._tasks[task_id]

    def list_tasks(self) -> OrderedDict:
        """
        Return the info dicts of all tasks, keyed by task ID in the order they were added.
        """
        return self._tasks

    @property
    def next_waiting(self):
        """
        Return the task_id of the first task still waiting to be run, or else None
        """
        waiting = self.status_codes['waiting']
        for task_id, info in self._tasks.items():
            # info is already the per-task dict, so it is indexed by field name only
            if info['status'] == waiting:
                return task_id
        return None

    def run_task(self, task_id: str):
        """
        Run a registered task synchronously and record its outcome in the task's info dict.
        :param task_id: ID of the task
        :return: whatever the task's callable returns
        :raises KeyError: if no task has this ID
        :raises RunTaskError: if the callable raises; the original exception is chained as the cause
        """
        task = self._tasks[task_id]
        f = self._handles[task_id]
        p = task['params']
        try:
            logger.info(f'Started running task {task_id}')
            task['status'] = self.status_codes['in_progress']
            result = f(p)
            logger.info(f'Finished running task {task_id}')
            task['status'] = self.status_codes['finished']
            task['result'] = result
            return result
        except Exception as e:
            # any failure inside the task is recorded on the task and reported uniformly
            task['status'] = self.status_codes['failed']
            task['error'] = str(e)
            logger.error(f'Error running task {task_id}: {str(e)}')
            raise RunTaskError(e) from e
@test_router.put('/tasks/create_dummy_task')
def create_dummy_task(params: DummyTaskParams) -> PipelineQueueRecord:
    """
    Register a dummy pipeline as a session task without running it.
    :param params: pipeline parameters; break_me makes the pipeline raise when the task is run
    :return: record holding the ID of the newly registered task
    """
    def _dummy_pipeline(accessors, models, **k):
        # start the trace from the accessor stored under the empty key
        d = PipelineTrace(accessors.get(''))
        if k.get('break_me'):
            raise Exception('I broke')
        # multiply the last accessor in the trace by two
        d['res'] = d.last.apply(lambda x: 2 * x)
        return d

    # the task calls the pipeline through call_pipeline only when it is eventually run
    task_id = session.tasks.add_task(
        lambda x: call_pipeline(_dummy_pipeline, x),
        params
    )
    return PipelineQueueRecord(task_id=task_id)
self.get_accessor(acc_id) + + task_id = self.assertPutSuccess( + '/testing/tasks/create_dummy_task', + body={'accessor_id': acc_id, 'break_me': True} + )['task_id'] + task_info1 = self.assertGetSuccess(f'/tasks/{task_id}') + self.assertEqual(task_info1['status'], 'WAITING') + + # run the task and compare results + rec = self.assertPutFailure(f'/tasks/{task_id}/run', 409) + task_info2 = self.assertGetSuccess(f'/tasks/{task_id}') + self.assertEqual(task_info2['status'], 'FAILED') diff --git a/tests/base/test_roiset_pipeline.py b/tests/base/test_roiset_pipeline.py index a32ec47b07137666479277f384eb4d42ca2e1eb8..f42527722f0745536e63ad685ef893e682827138 100644 --- a/tests/base/test_roiset_pipeline.py +++ b/tests/base/test_roiset_pipeline.py @@ -1,3 +1,4 @@ +import json from pathlib import Path import unittest @@ -106,7 +107,6 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(rois.count, 22) self.assertEqual(len(trace['ob_id'].unique()[0]), 2) - class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): input_data = data['multichannel_zstack_raw'] @@ -183,3 +183,109 @@ class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProd res = self._object_map_workflow(None) acc_obmap = self.get_accessor(res['output_accessor_id']) self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1])) + +class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts): + + input_data = data['multichannel_zstack_raw'] + + + def setUp(self) -> None: + self.where_out = output_path / 'roiset' + self.where_out.mkdir(parents=True, exist_ok=True) + return conf.TestServerBaseClass.setUp(self) + + def test_load_input_accessor(self): + fname = self.copy_input_file_to_server() + return self.assertPutSuccess(f'accessors/read_from_file/{fname}') + + def test_load_pixel_classifier(self): + mid = self.assertPutSuccess( + 'models/seg/threshold/load/', + query={'tr': 0.2}, + )['model_id'] + 
class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
    """
    Repeat the inherited RoiSet workflow tests with the pipeline scheduled as a task and then run explicitly.
    """

    def _object_map_workflow(self, ob_classifer_id):
        request_body = {
            'schedule': True,
            'accessor_id': self.test_load_input_accessor(),
            'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
            'object_classifier_model_id': ob_classifer_id,
            'segmentation': {'channel': 0},
            'patches_channel': 1,
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }
        queued = self.assertPutSuccess('pipelines/roiset_to_obmap/infer', body=request_body)

        # scheduling must only enqueue the task, not run it
        queued_task_id = queued['task_id']
        queued_info = self.assertGetSuccess(f'tasks/{queued_task_id}')
        self.assertEqual(queued_info['status'], 'WAITING')

        # running the task yields the response that the inherited tests go on to inspect
        record = self.assertPutSuccess(f'tasks/{queued_task_id}/run')
        self.assertTrue(record)
        finished_info = self.assertGetSuccess(f'tasks/{queued_task_id}')
        self.assertEqual(finished_info['status'], 'FINISHED')

        return record