From e147a54f3e84a0a721f17dce733f0b43a6ed2be6 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 31 Aug 2023 15:50:18 +0200 Subject: [PATCH] Session returns JSON-compliant model list --- api.py | 4 ++-- model_server/session.py | 9 +++++++-- tests/test_api.py | 8 +++----- tests/test_session.py | 7 ++++--- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/api.py b/api.py index 9518066f..ad40f5ef 100644 --- a/api.py +++ b/api.py @@ -23,8 +23,8 @@ def list_active_models(): return session.describe_loaded_models() @app.put('/models/dummy/load/') -def load_dummy_model(params: str = None) -> dict: - return session.load_model(DummyImageToImageModel, params) +def load_dummy_model() -> dict: + return session.load_model(DummyImageToImageModel) @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(params: str) -> dict: diff --git a/model_server/session.py b/model_server/session.py index 0ce71e13..6db22003 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -85,8 +85,13 @@ class Session(object): return self.describe_loaded_models() def describe_loaded_models(self) -> dict: - # TODO: explictly make this JSON-compatible - return self.models + return { + k: { + 'class': self.models[k]['object'].__class__.__name__, + 'params': self.models[k]['params'], + } + for k in self.models.keys() + } def restart(self): self.__init__() diff --git a/tests/test_api.py b/tests/test_api.py index 4e7c4819..d4f7d266 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -47,17 +47,15 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.content, b'{}') def test_load_dummy_model(self): - model_id = DummyImageToImageModel.model_id + model_key = DummyImageToImageModel.__name__ + '_00' resp_load = requests.put( self.uri + f'models/dummy/load', - # params={'misc': {'d': 'e'}} ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') - return model_id + self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel') def test_respond_with_error_when_invalid_model_loaded(self): model_id = 'not_a_real_model' @@ -69,7 +67,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): print(resp.content) def test_respond_with_error_when_invalid_filepath_requested(self): - model_id = self.test_load_dummy_model() + # model_id = self.test_load_dummy_model() resp = requests.put( self.uri + f'i2i/infer/', params={ diff --git a/tests/test_session.py b/tests/test_session.py index 2cfead8e..6a15c68b 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -44,12 +44,13 @@ class TestGetSessionObject(unittest.TestCase): success = sesh.load_model(MC) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() + print(loaded_models) self.assertTrue( (MC.__name__ + '_00') in loaded_models.keys() ) self.assertEqual( - loaded_models[MC.__name__ + '_00']['object'].__class__, - MC + loaded_models[MC.__name__ + '_00']['class'], + MC.__name__ ) def test_session_loads_second_instance_of_same_model(self): @@ -69,4 +70,4 @@ class TestGetSessionObject(unittest.TestCase): success = sesh.load_model(MC, params=p) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() - self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p) \ No newline at end of file + self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p) -- GitLab