From 5865f0aa2ff9c2a7dc9637e2464d7c52ecb90d4b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 4 Mar 2024 19:15:00 +0100
Subject: [PATCH] Mapped dummy model for instance segmentation to API as well,
 for testing purposes

---
 model_server/base/api.py |  8 ++++++--
 tests/test_api.py        | 20 +++++++++++++++-----
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 2ab37ad3..8259a98c 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 
-from model_server.base.models import DummySemanticSegmentationModel
+from model_server.base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
 from model_server.base.session import Session, InvalidPathError
 from model_server.base.validators import validate_workflow_inputs
 from model_server.base.workflows import classify_pixels
@@ -65,10 +65,14 @@ def restart_session(root: str = None) -> dict:
 def list_active_models():
     return session.describe_loaded_models()
 
-@app.put('/models/dummy/load/')
+@app.put('/models/dummy_semantic/load/')
 def load_dummy_model() -> dict:
     return {'model_id': session.load_model(DummySemanticSegmentationModel)}
 
+@app.put('/models/dummy_instance/load/')
+def load_dummy_model() -> dict:
+    return {'model_id': session.load_model(DummyInstanceSegmentationModel)}
+
 @app.put('/workflows/segment')
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
     inpath = session.paths['inbound_images'] / input_filename
diff --git a/tests/test_api.py b/tests/test_api.py
index 8036289b..1b201440 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -78,8 +78,8 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp.status_code, 200)
         self.assertEqual(resp.content, b'{}')
 
-    def test_load_dummy_model(self):
-        resp_load = self._put(f'models/dummy/load')
+    def test_load_dummy_semantic_model(self):
+        resp_load = self._put(f'models/dummy_semantic/load')
         model_id = resp_load.json()['model_id']
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
         resp_list = self._get('models')
@@ -88,8 +88,18 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel')
         return model_id
 
+    def test_load_dummy_instance_model(self):
+        resp_load = self._put(f'models/dummy_instance/load')
+        model_id = resp_load.json()['model_id']
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        resp_list = self._get('models')
+        self.assertEqual(resp_list.status_code, 200)
+        rj = resp_list.json()
+        self.assertEqual(rj[model_id]['class'], 'DummyInstanceSegmentationModel')
+        return model_id
+
     def test_respond_with_error_when_invalid_filepath_requested(self):
-        model_id = self.test_load_dummy_model()
+        model_id = self.test_load_dummy_semantic_model()
 
         resp = self._put(
             f'infer/from_image_file',
@@ -107,7 +117,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp.status_code, 409, resp.content.decode())
 
     def test_i2i_dummy_inference_by_api(self):
-        model_id = self.test_load_dummy_model()
+        model_id = self.test_load_dummy_semantic_model()
         self.copy_input_file_to_server()
         resp_infer = self._put(
             f'workflows/segment',
@@ -120,7 +130,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
 
     def test_restarting_session_clears_loaded_models(self):
-        resp_load = self._put(f'models/dummy/load',)
+        resp_load = self._put(f'models/dummy_semantic/load',)
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
         resp_list_0 = self._get('models')
         self.assertEqual(resp_list_0.status_code, 200)
-- 
GitLab