From a62aa5a575a67685689da0ddf0d197744ecd52b4 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 28 Aug 2023 11:54:52 +0200
Subject: [PATCH] Session now recursively searches for all subclasses of Model
 when loading a model

---
 model_server/model.py   | 11 +++++++++++
 model_server/session.py |  9 ++++++++-
 tests/test_api.py       |  9 +++++----
 tests/test_model.py     |  9 +++++++--
 tests/test_session.py   |  3 ++-
 5 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/model_server/model.py b/model_server/model.py
index 563b85d9..7c9d84e1 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -22,6 +22,17 @@ class Model(ABC):
             raise CouldNotLoadModelError()
         return None
 
+    @classmethod
+    def get_all_subclasses(cls):
+        """
+        Recursively find all subclasses of Model
+        :return: set of all subclasses of Model
+        """
+        def get_all_subclasses_of(cc):
+            return set(cc.__subclasses__()).union(
+                [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
+        return get_all_subclasses_of(cls)
+
     @abstractmethod
     def load(self):
         """
diff --git a/model_server/session.py b/model_server/session.py
index 3f4fa30a..66a3ac11 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -67,12 +67,16 @@ class Session(object):
         :param model_id:
         :return: True if model successfully loaded, False if not
         """
-        for mc in Model.__subclasses__():
+        models = Model.get_all_subclasses()
+        for mc in models:
             if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
                 mi = mc()
                 assert mi.loaded
                 self.models['model_id'] = mi
                 return True
+        raise CouldNotFindModelError(
+            f'Could not find {model_id} in:\n{models}',
+        )
         return False
 
     def describe_models(self) -> dict:
@@ -87,5 +91,8 @@ class Session(object):
 class Error(Exception):
     pass
 
+class CouldNotFindModelError(Error):
+    pass
+
 class InferenceRecordError(Error):
     pass
\ No newline at end of file
diff --git a/tests/test_api.py b/tests/test_api.py
index 4071c023..17afe2e3 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -34,10 +34,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
 
     def test_load_model(self):
         model_id = DummyImageToImageModel.model_id
-        resp = requests.get(self.uri + f'models/{model_id}/load')
-        self.assertEqual(resp.status_code, 200)
-        loaded = requests.get(self.uri + 'models')
-        self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}')
+        resp_load = requests.get(self.uri + f'models/{model_id}/load')
+        self.assertEqual(resp_load.status_code, 200)
+        resp_list = requests.get(self.uri + 'models')
+        self.assertEqual(resp_list.status_code, 200)
+        self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}')
 
     def test_i2i_inference_errors_model_not_sound(self):
         model_id = 'not_a_real_model'
diff --git a/tests/test_model.py b/tests/test_model.py
index aff9db86..0ca0938e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,7 +1,7 @@
 import unittest
 from conf.testing import czifile
 from model_server.image import CziImageFileAccessor
-from model_server.model import DummyImageToImageModel, CouldNotLoadModelError
+from model_server.model import DummyImageToImageModel, CouldNotLoadModelError, Model
 
 class TestCziImageFileAccess(unittest.TestCase):
     def setUp(self) -> None:
@@ -48,4 +48,9 @@ class TestCziImageFileAccess(unittest.TestCase):
             img[0, 0],
             0,
             'First pixel is not black as expected'
-        )
\ No newline at end of file
+        )
+
+    def test_find_subclasses_recursively(self):
+        sc = DummyImageToImageModel
+        scs = Model.get_all_subclasses()
+        self.assertIn(DummyImageToImageModel, scs)
\ No newline at end of file
diff --git a/tests/test_session.py b/tests/test_session.py
index 924f0cf6..09e0b1cc 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -35,6 +35,7 @@ class TestGetSessionObject(unittest.TestCase):
 
     def test_session_load_model(self):
         sesh = Session()
-        self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id))
+        success = sesh.load_model(DummyImageToImageModel.model_id)
+        self.assertTrue(success)
         self.assertTrue('model_id' in sesh.models.keys())
         self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
-- 
GitLab