From 5c7976268df92e7380a261f1159802fa602e3228 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 9 Nov 2024 09:08:33 +0100 Subject: [PATCH] Test covers task API --- model_server/conf/testing.py | 31 ++++++++++++++++++++++++++++--- tests/base/test_api.py | 20 +++++++++++++++++++- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index e5065c10..6682a4c4 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -15,6 +15,7 @@ from urllib3 import Retry from .fastapi import app from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor from ..base.models import SemanticSegmentationModel, InstanceMaskSegmentationModel +from ..base.pipelines.shared import call_pipeline, PipelineParams, PipelineQueueRecord, PipelineTrace from ..base.session import session from ..base.accessors import generate_file_accessor @@ -45,17 +46,35 @@ def load_dummy_accessor() -> str: return session.add_accessor(acc) @test_router.put('/models/dummy_semantic/load/') -def load_dummy_model() -> dict: +def load_dummy_semantic_model() -> dict: mid = session.load_model(DummySemanticSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} @test_router.put('/models/dummy_instance/load/') -def load_dummy_model() -> dict: +def load_dummy_instance_model() -> dict: mid = session.load_model(DummyInstanceMaskSegmentationModel) session.log_info(f'Loaded model {mid}') return {'model_id': mid} + +class DummyTaskParams(PipelineParams): + accessor_id: str + +@test_router.put('/tasks/create_dummy_task') +def create_dummy_task(params: DummyTaskParams) -> PipelineQueueRecord: + def _dummy_pipeline(accessors, models, **k): + d = PipelineTrace(accessors.get('')) + model = models.get('') + d['res'] = d.last.apply(lambda x: 2 * x) + return d + + task_id = session.queue.add_task( + lambda x: call_pipeline(_dummy_pipeline, x), + params + ) + return PipelineQueueRecord(task_id=task_id) + app.include_router(test_router) @@ -132,7 +151,13 @@ class TestServerBaseClass(unittest.TestCase): return self.input_data['name'] def get_accessor(self, accessor_id, filename=None, copy_to=None): - r = self.assertPutSuccess(f'/accessors/write_to_file/{accessor_id}', query={'filename': filename}) + r = self.assertPutSuccess( + f'/accessors/write_to_file/{accessor_id}', + query={ + 'filename': filename, + 'pop': False, + }, + ) fp_out = Path(self.assertGetSuccess('paths')['outbound_images']) / r self.assertTrue(fp_out.exists()) if copy_to: diff --git a/tests/base/test_api.py b/tests/base/test_api.py index 8e2ba825..8674fe86 100644 --- a/tests/base/test_api.py +++ b/tests/base/test_api.py @@ -197,4 +197,22 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertPutSuccess( '/models/classify/threshold/load', body={} - ) \ No newline at end of file + ) + + def test_run_dummy_task(self): + acc_id = self.assertPutSuccess('/testing/accessors/dummy_accessor/load') + acc_in = self.get_accessor(acc_id) + task_id = self.assertPutSuccess( + '/testing/tasks/create_dummy_task', + body={'accessor_id': acc_id} + )['task_id'] + task_info1 = self.assertGetSuccess(f'/tasks/{task_id}') + self.assertEqual(task_info1['status'], 'WAITING') + + rec = self.assertPutSuccess(f'/tasks/{task_id}/run') + + task_info2 = self.assertGetSuccess(f'/tasks/{task_id}') + self.assertEqual(task_info2['status'], 'FINISHED') + acc_out = self.get_accessor(task_info2['result']['output_accessor_id']) + self.assertTrue((acc_out.data == acc_in.data * 2).all()) + -- GitLab