From 5865f0aa2ff9c2a7dc9637e2464d7c52ecb90d4b Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 4 Mar 2024 19:15:00 +0100 Subject: [PATCH] Mapped dummy model for instance segmentation to API as well, for testing purposes --- model_server/base/api.py | 8 ++++++-- tests/test_api.py | 20 +++++++++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index 2ab37ad3..8259a98c 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, HTTPException -from model_server.base.models import DummySemanticSegmentationModel +from model_server.base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel from model_server.base.session import Session, InvalidPathError from model_server.base.validators import validate_workflow_inputs from model_server.base.workflows import classify_pixels @@ -65,10 +65,14 @@ def restart_session(root: str = None) -> dict: def list_active_models(): return session.describe_loaded_models() -@app.put('/models/dummy/load/') +@app.put('/models/dummy_semantic/load/') def load_dummy_model() -> dict: return {'model_id': session.load_model(DummySemanticSegmentationModel)} +@app.put('/models/dummy_instance/load/') +def load_dummy_model() -> dict: + return {'model_id': session.load_model(DummyInstanceSegmentationModel)} + @app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename diff --git a/tests/test_api.py b/tests/test_api.py index 8036289b..1b201440 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -78,8 +78,8 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.status_code, 200) self.assertEqual(resp.content, b'{}') - def test_load_dummy_model(self): - resp_load = self._put(f'models/dummy/load') + def test_load_dummy_semantic_model(self): + resp_load = self._put(f'models/dummy_semantic/load') model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = self._get('models') @@ -88,8 +88,18 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel') return model_id + def test_load_dummy_instance_model(self): + resp_load = self._put(f'models/dummy_instance/load') + model_id = resp_load.json()['model_id'] + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + resp_list = self._get('models') + self.assertEqual(resp_list.status_code, 200) + rj = resp_list.json() + self.assertEqual(rj[model_id]['class'], 'DummyInstanceSegmentationModel') + return model_id + def test_respond_with_error_when_invalid_filepath_requested(self): - model_id = self.test_load_dummy_model() + model_id = self.test_load_dummy_semantic_model() resp = self._put( f'infer/from_image_file', @@ -107,7 +117,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.status_code, 409, resp.content.decode()) def test_i2i_dummy_inference_by_api(self): - model_id = self.test_load_dummy_model() + model_id = self.test_load_dummy_semantic_model() self.copy_input_file_to_server() resp_infer = self._put( f'workflows/segment', @@ -120,7 +130,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) def test_restarting_session_clears_loaded_models(self): - resp_load = self._put(f'models/dummy/load',) + resp_load = self._put(f'models/dummy_semantic/load',) self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list_0 = self._get('models') self.assertEqual(resp_list_0.status_code, 200) -- GitLab