From b219399a3ff8efe36fca2567377aec9fa3330f5f Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 7 Oct 2024 16:15:23 +0200
Subject: [PATCH] Test server now has "dummy" endpoints configured, too

---
 model_server/conf/testing.py | 46 ++++++++++++++++++++++++++-
 tests/base/test_api.py       | 60 +++---------------------------------
 2 files changed, 49 insertions(+), 57 deletions(-)

diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py
index 042a9551..963527d7 100644
--- a/model_server/conf/testing.py
+++ b/model_server/conf/testing.py
@@ -6,21 +6,65 @@ from multiprocessing import Process
 from pathlib import Path
 from shutil import copyfile
 
+from fastapi import APIRouter
 import numpy as np
+from pydantic import BaseModel
 import requests
 from urllib3 import Retry
 
+from .fastapi import app
 from ..base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
 from ..base.models import SemanticSegmentationModel, InstanceSegmentationModel
+from ..base.session import session
 
 from ..base.accessors import generate_file_accessor
 
+"""
+Configure additional endpoints for testing
+"""
+test_router = APIRouter(prefix='/testing', tags=['testing'])
+
+class BounceBackParams(BaseModel):
+    par1: str
+    par2: list
+
+@test_router.put('/bounce_back')
+def list_bounce_back(params: BounceBackParams):
+    return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}}
+
+@test_router.put('/accessors/dummy_accessor/load')
+def load_dummy_accessor() -> str:
+    acc = InMemoryDataAccessor(
+        np.random.randint(
+            0,
+            2 ** 8,
+            size=(512, 256, 3, 7),
+            dtype='uint8'
+        )
+    )
+    return session.add_accessor(acc)
+
+@test_router.put('/models/dummy_semantic/load/')
+def load_dummy_model() -> dict:
+    mid = session.load_model(DummySemanticSegmentationModel)
+    session.log_info(f'Loaded model {mid}')
+    return {'model_id': mid}
+
+@test_router.put('/models/dummy_instance/load/')
+def load_dummy_model() -> dict:
+    mid = session.load_model(DummyInstanceSegmentationModel)
+    session.log_info(f'Loaded model {mid}')
+    return {'model_id': mid}
+
+app.include_router(test_router)
+
+
 class TestServerBaseClass(unittest.TestCase):
     """
     Base class for unittests of API functionality.  Implements both server and clients for testing.
     """
 
-    app_name = 'model_server.base.api:app'
+    app_name = 'model_server.conf.testing:app'
 
     def setUp(self) -> None:
         import uvicorn
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index e31b4366..74c6a79e 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -1,67 +1,15 @@
 from pathlib import Path
 
-from fastapi import APIRouter, FastAPI
-import numpy as np
-from pydantic import BaseModel
-
 import model_server.conf.testing as conf
-from model_server.base.accessors import InMemoryDataAccessor
-from model_server.base.api import app
-from model_server.base.session import session
-from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel
+from model_server.conf.testing import TestServerBaseClass
 
 czifile = conf.meta['image_files']['czifile']
 
-"""
-Configure additional endpoints for testing
-"""
-test_router = APIRouter(prefix='/testing', tags=['testing'])
-
-class BounceBackParams(BaseModel):
-    par1: str
-    par2: list
-
-@test_router.put('/bounce_back')
-def list_bounce_back(params: BounceBackParams):
-    return {'success': True, 'params': {'par1': params.par1, 'par2': params.par2}}
-
-@test_router.put('/accessors/dummy_accessor/load')
-def load_dummy_accessor() -> str:
-    acc = InMemoryDataAccessor(
-        np.random.randint(
-            0,
-            2 ** 8,
-            size=(512, 256, 3, 7),
-            dtype='uint8'
-        )
-    )
-    return session.add_accessor(acc)
-
-@test_router.put('/models/dummy_semantic/load/')
-def load_dummy_model() -> dict:
-    mid = session.load_model(DummySemanticSegmentationModel)
-    session.log_info(f'Loaded model {mid}')
-    return {'model_id': mid}
-
-@test_router.put('/models/dummy_instance/load/')
-def load_dummy_model() -> dict:
-    mid = session.load_model(DummyInstanceSegmentationModel)
-    session.log_info(f'Loaded model {mid}')
-    return {'model_id': mid}
-
-app.include_router(test_router)
 
-"""
-Implement unit testing on extended base app
-"""
-
-
-class TestServerTestCase(conf.TestServerBaseClass):
-    app_name = 'tests.base.test_api:app'
+class TestApiFromAutomatedClient(TestServerBaseClass):
+    
     input_data = czifile
-
-
-class TestApiFromAutomatedClient(TestServerTestCase):
+    
     def test_trivial_api_response(self):
         resp = self._get('')
         self.assertEqual(resp.status_code, 200)
-- 
GitLab