diff --git a/model_server/base/api.py b/model_server/base/api.py index 2ab37ad3da614c3a54ec1c01250a60d5dbf4fec4..8259a98c8f74a5ca25b6b2b3c6e4d8edc95fa547 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, HTTPException -from model_server.base.models import DummySemanticSegmentationModel +from model_server.base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel from model_server.base.session import Session, InvalidPathError from model_server.base.validators import validate_workflow_inputs from model_server.base.workflows import classify_pixels @@ -65,10 +65,14 @@ def restart_session(root: str = None) -> dict: def list_active_models(): return session.describe_loaded_models() -@app.put('/models/dummy/load/') +@app.put('/models/dummy_semantic/load/') def load_dummy_model() -> dict: return {'model_id': session.load_model(DummySemanticSegmentationModel)} +@app.put('/models/dummy_instance/load/') +def load_dummy_model() -> dict: + return {'model_id': session.load_model(DummyInstanceSegmentationModel)} + @app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename diff --git a/tests/test_api.py b/tests/test_api.py index 8036289b0b32b0876c0cbf53732ad36b73a2affd..1b201440da4504e16f10573d31747603926e3b17 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -78,8 +78,8 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.status_code, 200) self.assertEqual(resp.content, b'{}') - def test_load_dummy_model(self): - resp_load = self._put(f'models/dummy/load') + def test_load_dummy_semantic_model(self): + resp_load = self._put(f'models/dummy_semantic/load') model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = self._get('models') @@ -88,8 +88,18 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel') return model_id + def test_load_dummy_instance_model(self): + resp_load = self._put(f'models/dummy_instance/load') + model_id = resp_load.json()['model_id'] + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + resp_list = self._get('models') + self.assertEqual(resp_list.status_code, 200) + rj = resp_list.json() + self.assertEqual(rj[model_id]['class'], 'DummyInstanceSegmentationModel') + return model_id + def test_respond_with_error_when_invalid_filepath_requested(self): - model_id = self.test_load_dummy_model() + model_id = self.test_load_dummy_semantic_model() resp = self._put( f'infer/from_image_file', @@ -107,7 +117,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.status_code, 409, resp.content.decode()) def test_i2i_dummy_inference_by_api(self): - model_id = self.test_load_dummy_model() + model_id = self.test_load_dummy_semantic_model() self.copy_input_file_to_server() resp_infer = self._put( f'workflows/segment', @@ -120,7 +130,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) def test_restarting_session_clears_loaded_models(self): - resp_load = self._put(f'models/dummy/load',) + resp_load = self._put(f'models/dummy_semantic/load',) self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list_0 = self._get('models') self.assertEqual(resp_list_0.status_code, 200)