From 8ea565009b95b7995c27f12beeedb94706a08f96 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 3 Nov 2023 11:09:15 +0100 Subject: [PATCH] Anything that creates a model returns the model_id as a string --- model_server/session.py | 10 +++++----- tests/test_api.py | 2 +- tests/test_session.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/model_server/session.py b/model_server/session.py index c8996b39..f9ef3d42 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -98,7 +98,7 @@ class Session(object): Load an instance of a given model class and attach to this session's model registry :param ModelClass: subclass of Model :param params: optional parameters that are passed to the model's construct - :return: dictionary that describes all currently loaded models + :return: model_id of loaded model """ mi = ModelClass(params=params) assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' @@ -116,7 +116,7 @@ class Session(object): 'params': params } self.log_event(f'Loaded model {key}') - return {key: self.models[key]} + return key def describe_loaded_models(self) -> dict: return { @@ -129,7 +129,7 @@ class Session(object): def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict: """ - Returns first instance of loaded model where key and value match with .params field, or None + Returns model_id of first model where key and value match with .params field, or None :param is_path: uses platform-independent path comparison if True """ @@ -137,10 +137,10 @@ class Session(object): for mid, det in models.items(): if is_path: if Path(det.get('params').get(key)) == Path(value): - return {mid: det} + return mid else: if det.get('params').get(key) == value: - return {mid: det} + return mid return None def restart(self, **kwargs): diff --git a/tests/test_api.py b/tests/test_api.py index c81c2d02..e6c0f799 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -65,7 +65,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_load = requests.put( self.uri + f'models/dummy/load', ) - model_id = list(resp_load.json()['model_id'].keys())[0] + model_id = resp_load.json()['model_id'] self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) diff --git a/tests/test_session.py b/tests/test_session.py index 555e020b..e3c93192 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -97,19 +97,19 @@ class TestGetSessionObject(unittest.TestCase): # load a second model and confirm that the first is locatable by its param entry p2 = {'p2': 'def'} sesh.load_model(MC, params=p2) - find_kv = sesh.find_param_in_loaded_models('p1', 'abc') - self.assertEqual(len(find_kv), 1) - self.assertEqual(find_kv[mid]['params'], p1) + find_mid = sesh.find_param_in_loaded_models('p1', 'abc') + self.assertEqual(mid, find_mid) + self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1) def test_session_finds_existing_model_with_different_path_formats(self): sesh = Session() MC = DummyImageToImageModel p1 = {'path': 'c:\\windows\\dummy.pa'} p2 = {'path': 'c:/windows/dummy.pa'} - sesh.load_model(MC, params=p1) + mid = sesh.load_model(MC, params=p1) assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) - find_kv = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) - self.assertEqual(len(find_kv), 1) + find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) + self.assertEqual(mid, find_mid) def test_change_output_path(self): import pathlib -- GitLab