diff --git a/model_server/model.py b/model_server/model.py index 563b85d9834623f234bc1526157c7b1d331e6794..7c9d84e1184640703a6a0574d1f4c18d6b19cbac 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -22,6 +22,17 @@ class Model(ABC): raise CouldNotLoadModelError() return None + @classmethod + def get_all_subclasses(cls): + """ + Recursively find all subclasses of Model + :return: set of all subclasses of Model + """ + def get_all_subclasses_of(cc): + return set(cc.__subclasses__()).union( + [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)]) + return get_all_subclasses_of(cls) + @abstractmethod def load(self): """ diff --git a/model_server/session.py b/model_server/session.py index 3f4fa30a561551b85e5501cc3c369703786153e6..66a3ac119e020a0f143d9556c7c23fffdb8392ac 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -67,12 +67,16 @@ class Session(object): :param model_id: :return: True if model successfully loaded, False if not """ - for mc in Model.__subclasses__(): + models = Model.get_all_subclasses() + for mc in models: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: mi = mc() assert mi.loaded self.models['model_id'] = mi return True + raise CouldNotFindModelError( + f'Could not find {model_id} in:\n{models}', + ) return False def describe_models(self) -> dict: @@ -87,5 +91,8 @@ class Session(object): class Error(Exception): pass +class CouldNotFindModelError(Error): + pass + class InferenceRecordError(Error): pass \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 4071c023b454b0cf5a38db4cfcbc97a641239c22..17afe2e3a9645eaebfacc4bfdad33098374b6a07 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -34,10 +34,11 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_load_model(self): model_id = DummyImageToImageModel.model_id - resp = requests.get(self.uri + f'models/{model_id}/load') - self.assertEqual(resp.status_code, 200) - loaded = requests.get(self.uri + 'models') - self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}') + resp_load = requests.get(self.uri + f'models/{model_id}/load') + self.assertEqual(resp_load.status_code, 200) + resp_list = requests.get(self.uri + 'models') + self.assertEqual(resp_list.status_code, 200) + self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}') def test_i2i_inference_errors_model_not_sound(self): model_id = 'not_a_real_model' diff --git a/tests/test_model.py b/tests/test_model.py index aff9db8639b8380f1ab89cf7f8f061ce7a1ea010..0ca0938e2ef541c016eba701a5bfc45eb1d5712a 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,7 +1,7 @@ import unittest from conf.testing import czifile from model_server.image import CziImageFileAccessor -from model_server.model import DummyImageToImageModel, CouldNotLoadModelError +from model_server.model import DummyImageToImageModel, CouldNotLoadModelError, Model class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -48,4 +48,9 @@ class TestCziImageFileAccess(unittest.TestCase): img[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + + def test_find_subclasses_recursively(self): + sc = DummyImageToImageModel + scs = Model.get_all_subclasses() + self.assertIn(DummyImageToImageModel, scs) \ No newline at end of file diff --git a/tests/test_session.py b/tests/test_session.py index 924f0cf64812cfcaae4ec1cee18c88ca3dab90ae..09e0b1ccc110c54caf57adc3449ba72dc386339d 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -35,6 +35,7 @@ class TestGetSessionObject(unittest.TestCase): def test_session_load_model(self): sesh = Session() - self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id)) + success = sesh.load_model(DummyImageToImageModel.model_id) + self.assertTrue(success) self.assertTrue('model_id' in sesh.models.keys()) self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) \ No newline at end of file