diff --git a/api.py b/api.py index 9518066f46722be53132db03c4508b8d156f6632..ad40f5ef7edb8bd6c7648b3b62c0ffab3dff732e 100644 --- a/api.py +++ b/api.py @@ -23,8 +23,8 @@ def list_active_models(): return session.describe_loaded_models() @app.put('/models/dummy/load/') -def load_dummy_model(params: str = None) -> dict: - return session.load_model(DummyImageToImageModel, params) +def load_dummy_model() -> dict: + return session.load_model(DummyImageToImageModel) @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(params: str) -> dict: diff --git a/model_server/session.py b/model_server/session.py index 0ce71e1388dcf153ee4e24571a410c801ade691c..6db22003c7e2bfeb4e45bdde566eb7492525bebb 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -85,8 +85,13 @@ class Session(object): return self.describe_loaded_models() def describe_loaded_models(self) -> dict: - # TODO: explictly make this JSON-compatible - return self.models + return { + k: { + 'class': self.models[k]['object'].__class__.__name__, + 'params': self.models[k]['params'], + } + for k in self.models.keys() + } def restart(self): self.__init__() diff --git a/tests/test_api.py b/tests/test_api.py index 4e7c4819482c21a2fef2e9bf563747d92f21dc97..d4f7d266e68d441f8a7bb566d40e8758055baa2d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -47,17 +47,15 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(resp.content, b'{}') def test_load_dummy_model(self): - model_id = DummyImageToImageModel.model_id + model_key = DummyImageToImageModel.__name__ + '_00' resp_load = requests.put( self.uri + f'models/dummy/load', - # params={'misc': {'d': 'e'}} ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') - return model_id + self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel') def test_respond_with_error_when_invalid_model_loaded(self): model_id = 'not_a_real_model' @@ -69,7 +67,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): print(resp.content) def test_respond_with_error_when_invalid_filepath_requested(self): - model_id = self.test_load_dummy_model() + # model_id = self.test_load_dummy_model() resp = requests.put( self.uri + f'i2i/infer/', params={ diff --git a/tests/test_session.py b/tests/test_session.py index 2cfead8e44a970b9a89923ee0452f7c3c55027b3..6a15c68be361d5e9eab2a668c5b186d359dc1047 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -44,12 +44,13 @@ class TestGetSessionObject(unittest.TestCase): success = sesh.load_model(MC) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() + print(loaded_models) self.assertTrue( (MC.__name__ + '_00') in loaded_models.keys() ) self.assertEqual( - loaded_models[MC.__name__ + '_00']['object'].__class__, - MC + loaded_models[MC.__name__ + '_00']['class'], + MC.__name__ ) def test_session_loads_second_instance_of_same_model(self): @@ -69,4 +70,4 @@ class TestGetSessionObject(unittest.TestCase): success = sesh.load_model(MC, params=p) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() - self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p) \ No newline at end of file + self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)