From 8ea565009b95b7995c27f12beeedb94706a08f96 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 3 Nov 2023 11:09:15 +0100
Subject: [PATCH] Anything that creates a model returns the model_id as a
 string

---
 model_server/session.py | 10 +++++-----
 tests/test_api.py       |  2 +-
 tests/test_session.py   | 12 ++++++------
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/model_server/session.py b/model_server/session.py
index c8996b39..f9ef3d42 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -98,7 +98,7 @@ class Session(object):
         Load an instance of a given model class and attach to this session's model registry
         :param ModelClass: subclass of Model
         :param params: optional parameters that are passed to the model's construct
-        :return: dictionary that describes all currently loaded models
+        :return: model_id of loaded model
         """
         mi = ModelClass(params=params)
         assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
@@ -116,7 +116,7 @@ class Session(object):
             'params': params
         }
         self.log_event(f'Loaded model {key}')
-        return {key: self.models[key]}
+        return key
 
     def describe_loaded_models(self) -> dict:
         return {
@@ -129,7 +129,7 @@ class Session(object):
 
     def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict:
         """
-        Returns first instance of loaded model where key and value match with .params field, or None
+        Returns model_id of first model where key and value match with .params field, or None
         :param is_path: uses platform-independent path comparison if True
         """
 
@@ -137,10 +137,10 @@ class Session(object):
         for mid, det in models.items():
             if is_path:
                 if Path(det.get('params').get(key)) == Path(value):
-                    return {mid: det}
+                    return mid
             else:
                 if det.get('params').get(key) == value:
-                    return {mid: det}
+                    return mid
         return None
 
     def restart(self, **kwargs):
diff --git a/tests/test_api.py b/tests/test_api.py
index c81c2d02..e6c0f799 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -65,7 +65,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         resp_load = requests.put(
             self.uri + f'models/dummy/load',
         )
-        model_id = list(resp_load.json()['model_id'].keys())[0]
+        model_id = resp_load.json()['model_id']
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
         resp_list = requests.get(self.uri + 'models')
         self.assertEqual(resp_list.status_code, 200)
diff --git a/tests/test_session.py b/tests/test_session.py
index 555e020b..e3c93192 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -97,19 +97,19 @@ class TestGetSessionObject(unittest.TestCase):
         # load a second model and confirm that the first is locatable by its param entry
         p2 = {'p2': 'def'}
         sesh.load_model(MC, params=p2)
-        find_kv = sesh.find_param_in_loaded_models('p1', 'abc')
-        self.assertEqual(len(find_kv), 1)
-        self.assertEqual(find_kv[mid]['params'], p1)
+        find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
+        self.assertEqual(mid, find_mid)
+        self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
 
     def test_session_finds_existing_model_with_different_path_formats(self):
         sesh = Session()
         MC = DummyImageToImageModel
         p1 = {'path': 'c:\\windows\\dummy.pa'}
         p2 = {'path': 'c:/windows/dummy.pa'}
-        sesh.load_model(MC, params=p1)
+        mid = sesh.load_model(MC, params=p1)
         assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
-        find_kv = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
-        self.assertEqual(len(find_kv), 1)
+        find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
+        self.assertEqual(mid, find_mid)
 
     def test_change_output_path(self):
         import pathlib
-- 
GitLab