Skip to content
Snippets Groups Projects
Commit a62aa5a5 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Session now recursively searches for all subclasses of Model when loading a model

parent 4d2f4d7a
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,17 @@ class Model(ABC): ...@@ -22,6 +22,17 @@ class Model(ABC):
raise CouldNotLoadModelError() raise CouldNotLoadModelError()
return None return None
@classmethod
def get_all_subclasses(cls):
"""
Recursively find all subclasses of Model
:return: set of all subclasses of Model
"""
def get_all_subclasses_of(cc):
return set(cc.__subclasses__()).union(
[s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
return get_all_subclasses_of(cls)
@abstractmethod @abstractmethod
def load(self): def load(self):
""" """
......
...@@ -67,12 +67,16 @@ class Session(object): ...@@ -67,12 +67,16 @@ class Session(object):
:param model_id: :param model_id:
:return: True if model successfully loaded, False if not :return: True if model successfully loaded, False if not
""" """
for mc in Model.__subclasses__(): models = Model.get_all_subclasses()
for mc in models:
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc() mi = mc()
assert mi.loaded assert mi.loaded
self.models['model_id'] = mi self.models['model_id'] = mi
return True return True
raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}',
)
return False return False
def describe_models(self) -> dict: def describe_models(self) -> dict:
...@@ -87,5 +91,8 @@ class Session(object): ...@@ -87,5 +91,8 @@ class Session(object):
class Error(Exception): class Error(Exception):
pass pass
class CouldNotFindModelError(Error):
pass
class InferenceRecordError(Error): class InferenceRecordError(Error):
pass pass
\ No newline at end of file
...@@ -34,10 +34,11 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -34,10 +34,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_load_model(self): def test_load_model(self):
model_id = DummyImageToImageModel.model_id model_id = DummyImageToImageModel.model_id
resp = requests.get(self.uri + f'models/{model_id}/load') resp_load = requests.get(self.uri + f'models/{model_id}/load')
self.assertEqual(resp.status_code, 200) self.assertEqual(resp_load.status_code, 200)
loaded = requests.get(self.uri + 'models') resp_list = requests.get(self.uri + 'models')
self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}') self.assertEqual(resp_list.status_code, 200)
self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}')
def test_i2i_inference_errors_model_not_sound(self): def test_i2i_inference_errors_model_not_sound(self):
model_id = 'not_a_real_model' model_id = 'not_a_real_model'
......
import unittest import unittest
from conf.testing import czifile from conf.testing import czifile
from model_server.image import CziImageFileAccessor from model_server.image import CziImageFileAccessor
from model_server.model import DummyImageToImageModel, CouldNotLoadModelError from model_server.model import DummyImageToImageModel, CouldNotLoadModelError, Model
class TestCziImageFileAccess(unittest.TestCase): class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
...@@ -48,4 +48,9 @@ class TestCziImageFileAccess(unittest.TestCase): ...@@ -48,4 +48,9 @@ class TestCziImageFileAccess(unittest.TestCase):
img[0, 0], img[0, 0],
0, 0,
'First pixel is not black as expected' 'First pixel is not black as expected'
) )
\ No newline at end of file
def test_find_subclasses_recursively(self):
sc = DummyImageToImageModel
scs = Model.get_all_subclasses()
self.assertIn(DummyImageToImageModel, scs)
\ No newline at end of file
...@@ -35,6 +35,7 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -35,6 +35,7 @@ class TestGetSessionObject(unittest.TestCase):
def test_session_load_model(self): def test_session_load_model(self):
sesh = Session() sesh = Session()
self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id)) success = sesh.load_model(DummyImageToImageModel.model_id)
self.assertTrue(success)
self.assertTrue('model_id' in sesh.models.keys()) self.assertTrue('model_id' in sesh.models.keys())
self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment