Skip to content
Snippets Groups Projects
Commit a62aa5a5 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Session now recursively searches for all subclasses of Model when loading a model

parent 4d2f4d7a
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,17 @@ class Model(ABC):
raise CouldNotLoadModelError()
return None
@classmethod
def get_all_subclasses(cls):
"""
Recursively find all subclasses of Model
:return: set of all subclasses of Model
"""
def get_all_subclasses_of(cc):
return set(cc.__subclasses__()).union(
[s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
return get_all_subclasses_of(cls)
@abstractmethod
def load(self):
"""
......
......@@ -67,12 +67,16 @@ class Session(object):
:param model_id:
:return: True if model successfully loaded, False if not
"""
for mc in Model.__subclasses__():
models = Model.get_all_subclasses()
for mc in models:
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc()
assert mi.loaded
self.models['model_id'] = mi
return True
raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}',
)
return False
def describe_models(self) -> dict:
......@@ -87,5 +91,8 @@ class Session(object):
class Error(Exception):
pass
class CouldNotFindModelError(Error):
pass
class InferenceRecordError(Error):
pass
\ No newline at end of file
......@@ -34,10 +34,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_load_model(self):
model_id = DummyImageToImageModel.model_id
resp = requests.get(self.uri + f'models/{model_id}/load')
self.assertEqual(resp.status_code, 200)
loaded = requests.get(self.uri + 'models')
self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}')
resp_load = requests.get(self.uri + f'models/{model_id}/load')
self.assertEqual(resp_load.status_code, 200)
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}')
def test_i2i_inference_errors_model_not_sound(self):
model_id = 'not_a_real_model'
......
import unittest
from conf.testing import czifile
from model_server.image import CziImageFileAccessor
from model_server.model import DummyImageToImageModel, CouldNotLoadModelError
from model_server.model import DummyImageToImageModel, CouldNotLoadModelError, Model
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
......@@ -48,4 +48,9 @@ class TestCziImageFileAccess(unittest.TestCase):
img[0, 0],
0,
'First pixel is not black as expected'
)
\ No newline at end of file
)
def test_find_subclasses_recursively(self):
sc = DummyImageToImageModel
scs = Model.get_all_subclasses()
self.assertIn(DummyImageToImageModel, scs)
\ No newline at end of file
......@@ -35,6 +35,7 @@ class TestGetSessionObject(unittest.TestCase):
def test_session_load_model(self):
sesh = Session()
self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id))
success = sesh.load_model(DummyImageToImageModel.model_id)
self.assertTrue(success)
self.assertTrue('model_id' in sesh.models.keys())
self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment