diff --git a/model_server/base/api.py b/model_server/base/api.py index d91ed5854350d1567fa68ae58c4f72f751063864..1a533b92256f5c2d61ac054186aee55dfd807def 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -6,7 +6,7 @@ from .accessors import generate_file_accessor from .models import BinaryThresholdSegmentationModel from .pipelines.shared import PipelineRecord from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError -from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError +from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, RunTaskError, WriteAccessorError app = FastAPI(debug=True) @@ -220,7 +220,10 @@ class TaskInfo(BaseModel): @app.put('/tasks/{task_id}/run') def run_task(task_id: str) -> PipelineRecord: - return session.queue.run_task(task_id) + try: + return session.queue.run_task(task_id) + except RunTaskError as e: + raise HTTPException(409, str(e)) @app.get('/tasks/{task_id}') def get_task(task_id: str) -> TaskInfo: diff --git a/model_server/base/session.py b/model_server/base/session.py index 12080c676816d009607d87a6250f79b7b7359c91..9f55935ee3bf02f73a9aacb6b012fcffc5e516a0 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -82,8 +82,8 @@ class Queue(object): return result except Exception as e: task['status'] = self.status_codes['failed'] - task['error'] = e - raise e + task['error'] = str(e) + raise RunTaskError(e) class _Session(object): """ @@ -472,4 +472,7 @@ class CouldNotAppendToTable(Error): pass class InvalidPathError(Error): + pass + +class RunTaskError(Error): pass \ No newline at end of file diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 6682a4c4bd273610273b2f26ac44f9d3083380bf..54d82907aea65743154231f108ee90be2483cc4b 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -60,11 +60,14 @@ def load_dummy_instance_model() -> dict: class DummyTaskParams(PipelineParams): accessor_id: str + break_me: bool = False @test_router.put('/tasks/create_dummy_task') def create_dummy_task(params: DummyTaskParams) -> PipelineQueueRecord: def _dummy_pipeline(accessors, models, **k): d = PipelineTrace(accessors.get('')) + if k.get('break_me'): + raise Exception('I broke') model = models.get('') d['res'] = d.last.apply(lambda x: 2 * x) return d diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 8674fe86d8789c67c8c296ce363fba2a212a0b1c..52d9456114b61211fdb6ce2f59422b7813a8216a 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -202,6 +202,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_run_dummy_task(self): acc_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load') acc_in = self.get_accessor(acc_id) + task_id = self.assertPutSuccess( '/testing/tasks/create_dummy_task', body={'accessor_id': acc_id} @@ -209,10 +210,25 @@ class TestApiFromAutomatedClient(TestServerBaseClass): task_info1 = self.assertGetSuccess(f'/tasks/{task_id}') self.assertEqual(task_info1['status'], 'WAITING') + # run the task and compare results rec = self.assertPutSuccess(f'/tasks/{task_id}/run') - task_info2 = self.assertGetSuccess(f'/tasks/{task_id}') self.assertEqual(task_info2['status'], 'FINISHED') acc_out = self.get_accessor(task_info2['result']['output_accessor_id']) self.assertTrue((acc_out.data == acc_in.data * 2).all()) + def test_run_failed_dummy_task(self): + acc_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load') + acc_in = self.get_accessor(acc_id) + + task_id = self.assertPutSuccess( + '/testing/tasks/create_dummy_task', + body={'accessor_id': acc_id, 'break_me': True} + )['task_id'] + task_info1 = self.assertGetSuccess(f'/tasks/{task_id}') + self.assertEqual(task_info1['status'], 'WAITING') + + # run the task and compare results + rec = self.assertPutFailure(f'/tasks/{task_id}/run', 409) + task_info2 = self.assertGetSuccess(f'/tasks/{task_id}') + self.assertEqual(task_info2['status'], 'FAILED')