From a62aa5a575a67685689da0ddf0d197744ecd52b4 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 28 Aug 2023 11:54:52 +0200 Subject: [PATCH] Session now recursively searches for all subclasses of Model when loading a model --- model_server/model.py | 11 +++++++++++ model_server/session.py | 9 ++++++++- tests/test_api.py | 9 +++++---- tests/test_model.py | 9 +++++++-- tests/test_session.py | 3 ++- 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/model_server/model.py b/model_server/model.py index 563b85d9..7c9d84e1 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -22,6 +22,17 @@ class Model(ABC): raise CouldNotLoadModelError() return None + @classmethod + def get_all_subclasses(cls): + """ + Recursively find all subclasses of Model + :return: set of all subclasses of Model + """ + def get_all_subclasses_of(cc): + return set(cc.__subclasses__()).union( + [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)]) + return get_all_subclasses_of(cls) + @abstractmethod def load(self): """ diff --git a/model_server/session.py b/model_server/session.py index 3f4fa30a..66a3ac11 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -67,12 +67,16 @@ class Session(object): :param model_id: :return: True if model successfully loaded, False if not """ - for mc in Model.__subclasses__(): + models = Model.get_all_subclasses() + for mc in models: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: mi = mc() assert mi.loaded self.models['model_id'] = mi return True + raise CouldNotFindModelError( + f'Could not find {model_id} in:\n{models}', + ) return False def describe_models(self) -> dict: @@ -87,5 +91,8 @@ class Session(object): class Error(Exception): pass +class CouldNotFindModelError(Error): + pass + class InferenceRecordError(Error): pass \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 4071c023..17afe2e3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -34,10 +34,11 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_load_model(self): model_id = DummyImageToImageModel.model_id - resp = requests.get(self.uri + f'models/{model_id}/load') - self.assertEqual(resp.status_code, 200) - loaded = requests.get(self.uri + 'models') - self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}') + resp_load = requests.get(self.uri + f'models/{model_id}/load') + self.assertEqual(resp_load.status_code, 200) + resp_list = requests.get(self.uri + 'models') + self.assertEqual(resp_list.status_code, 200) + self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}') def test_i2i_inference_errors_model_not_sound(self): model_id = 'not_a_real_model' diff --git a/tests/test_model.py b/tests/test_model.py index aff9db86..0ca0938e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,7 +1,7 @@ import unittest from conf.testing import czifile from model_server.image import CziImageFileAccessor -from model_server.model import DummyImageToImageModel, CouldNotLoadModelError +from model_server.model import DummyImageToImageModel, CouldNotLoadModelError, Model class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -48,4 +48,9 @@ class TestCziImageFileAccess(unittest.TestCase): img[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + + def test_find_subclasses_recursively(self): + sc = DummyImageToImageModel + scs = Model.get_all_subclasses() + self.assertIn(DummyImageToImageModel, scs) \ No newline at end of file diff --git a/tests/test_session.py b/tests/test_session.py index 924f0cf6..09e0b1cc 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -35,6 +35,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_load_model(self): sesh = Session() - self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id)) + success = sesh.load_model(DummyImageToImageModel.model_id) + self.assertTrue(success) self.assertTrue('model_id' in sesh.models.keys()) self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) \ No newline at end of file -- GitLab