diff --git a/api.py b/api.py index 0640ad67ecfc43bf7b5ea5e263dff3c613822d40..5a5f25aaa0572aa520be494d097a3a7bcafd0582 100644 --- a/api.py +++ b/api.py @@ -18,16 +18,17 @@ def read_root(): @app.get('/models') def list_active_models(): - return session.models # TODO: include model type too + return session.describe_models() @app.get('/models/{model_id}/load/') -def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON? +def load_model(model_id: str) -> dict: if model_id in session.models.keys(): raise HTTPException( status_code=409, detail=f'Model with id {model_id} has already been loaded' ) - session + session.load_model(model_id) + return session.describe_models() @app.post('/i2i/infer/{model_id}') # image file in, image file out def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: diff --git a/model_server/model.py b/model_server/model.py index 6eafdf3d084c93d1a146ccb0642d91c13fa621d6..563b85d9834623f234bc1526157c7b1d331e6794 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -20,6 +20,7 @@ class Model(ABC): else: self.loaded = False raise CouldNotLoadModelError() + return None @abstractmethod def load(self): @@ -44,6 +45,7 @@ class Model(ABC): def reload(self): self.load() + class ImageToImageModel(Model): def __init__(self, **kwargs): @@ -51,7 +53,7 @@ class ImageToImageModel(Model): Abstract class for models that receive an image and return an image of the same size :param kwargs: variable length keyword arguments """ - return super(**kwargs) + return super().__init__(**kwargs) @abstractmethod def infer(self, img, channel=None) -> (np.ndarray, dict): @@ -60,7 +62,7 @@ class ImageToImageModel(Model): class IlastikImageToImageModel(ImageToImageModel): pass -class DummyImageToImageModel(Model): +class DummyImageToImageModel(ImageToImageModel): model_id = 'dummy_make_white_square' diff --git a/model_server/session.py b/model_server/session.py index 17958bb9ed0c881fb4974350c5ce86f27180ffba..3f4fa30a561551b85e5501cc3c369703786153e6 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -71,10 +71,15 @@ class Session(object): if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: mi = mc() assert mi.loaded - self.models.append(mi) + self.models['model_id'] = mi return True return False + def describe_models(self) -> dict: + return { + k: self.models[k].__class__.__name__ + for k in self.models.keys() + } def restart(self): self.__init__() diff --git a/tests/test_api.py b/tests/test_api.py index f7c778c06658922345de248be5a978ad739eff6f..4071c023b454b0cf5a38db4cfcbc97a641239c22 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -29,11 +29,15 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_list_empty_loaded_models(self): resp = requests.get(self.uri + 'models') - print(resp.content) self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.content, b'{}') def test_load_model(self): - resp = requests.get(self.uri + 'load_') + model_id = DummyImageToImageModel.model_id + resp = requests.get(self.uri + f'models/{model_id}/load') + self.assertEqual(resp.status_code, 200) + loaded = requests.get(self.uri + 'models') + self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}') def test_i2i_inference_errors_model_not_sound(self): model_id = 'not_a_real_model' diff --git a/tests/test_session.py b/tests/test_session.py index 63aaea5bf7f76a282bc22de2fb7ff1617cb5b20b..924f0cf64812cfcaae4ec1cee18c88ca3dab90ae 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,4 +1,5 @@ import unittest +from model_server.model import DummyImageToImageModel from model_server.session import Session class TestGetSessionObject(unittest.TestCase): @@ -32,3 +33,8 @@ class TestGetSessionObject(unittest.TestCase): do = json.load(fh) self.assertEqual(di.dict(), do, 'Manifest record is not correct') + def test_session_load_model(self): + sesh = Session() + self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id)) + self.assertTrue('model_id' in sesh.models.keys()) + self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) \ No newline at end of file