From d331f923a3ffe1fbf192d10b7956e35c11229ae5 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 28 Aug 2023 11:33:20 +0200 Subject: [PATCH] Changed superclass of dummy model --- api.py | 7 ++++--- model_server/model.py | 6 ++++-- model_server/session.py | 7 ++++++- tests/test_api.py | 8 ++++++-- tests/test_session.py | 6 ++++++ 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/api.py b/api.py index 0640ad67..5a5f25aa 100644 --- a/api.py +++ b/api.py @@ -18,16 +18,17 @@ def read_root(): @app.get('/models') def list_active_models(): - return session.models # TODO: include model type too + return session.describe_models() @app.get('/models/{model_id}/load/') -def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON? +def load_model(model_id: str) -> dict: if model_id in session.models.keys(): raise HTTPException( status_code=409, detail=f'Model with id {model_id} has already been loaded' ) - session + session.load_model(model_id) + return session.describe_models() @app.post('/i2i/infer/{model_id}') # image file in, image file out def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: diff --git a/model_server/model.py b/model_server/model.py index 6eafdf3d..563b85d9 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -20,6 +20,7 @@ class Model(ABC): else: self.loaded = False raise CouldNotLoadModelError() + return None @abstractmethod def load(self): @@ -44,6 +45,7 @@ class Model(ABC): def reload(self): self.load() + class ImageToImageModel(Model): def __init__(self, **kwargs): @@ -51,7 +53,7 @@ class ImageToImageModel(Model): Abstract class for models that receive an image and return an image of the same size :param kwargs: variable length keyword arguments """ - return super(**kwargs) + return super().__init__(**kwargs) @abstractmethod def infer(self, img, channel=None) -> (np.ndarray, dict): @@ -60,7 +62,7 @@ class ImageToImageModel(Model): class IlastikImageToImageModel(ImageToImageModel): pass -class DummyImageToImageModel(Model): +class DummyImageToImageModel(ImageToImageModel): model_id = 'dummy_make_white_square' diff --git a/model_server/session.py b/model_server/session.py index 17958bb9..3f4fa30a 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -71,10 +71,15 @@ class Session(object): if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: mi = mc() assert mi.loaded - self.models.append(mi) + self.models['model_id'] = mi return True return False + def describe_models(self) -> dict: + return { + k: self.models[k].__class__.__name__ + for k in self.models.keys() + } def restart(self): self.__init__() diff --git a/tests/test_api.py b/tests/test_api.py index f7c778c0..4071c023 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -29,11 +29,15 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_list_empty_loaded_models(self): resp = requests.get(self.uri + 'models') - print(resp.content) self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.content, b'{}') def test_load_model(self): - resp = requests.get(self.uri + 'load_') + model_id = DummyImageToImageModel.model_id + resp = requests.get(self.uri + f'models/{model_id}/load') + self.assertEqual(resp.status_code, 200) + loaded = requests.get(self.uri + 'models') + self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}') def test_i2i_inference_errors_model_not_sound(self): model_id = 'not_a_real_model' diff --git a/tests/test_session.py b/tests/test_session.py index 63aaea5b..924f0cf6 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,5 @@ import unittest +from model_server.model import DummyImageToImageModel from model_server.session import Session class TestGetSessionObject(unittest.TestCase): @@ -32,3 +33,8 @@ class TestGetSessionObject(unittest.TestCase): do = json.load(fh) self.assertEqual(di.dict(), do, 'Manifest record is not correct') + def test_session_load_model(self): + sesh = Session() + self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id)) + self.assertTrue('model_id' in sesh.models.keys()) + self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) \ No newline at end of file -- GitLab