From aaa18478b03dfe2640205c88131d1b449af4248c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 3 Nov 2023 11:31:07 +0100
Subject: [PATCH] Model-loading always returns {'model_id': _} when successfuly

---
 extensions/ilastik/router.py             | 6 +++---
 extensions/ilastik/tests/test_ilastik.py | 5 +++--
 model_server/session.py                  | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py
index 0503ebaa..14ec3258 100644
--- a/extensions/ilastik/router.py
+++ b/extensions/ilastik/router.py
@@ -23,9 +23,9 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file:
     :return: dict containing model's ID
     """
     if not duplicate:
-        existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
-        if existing_model is not None:
-            return existing_model
+        existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
+        if existing_model_id is not None:
+            return {'model_id': existing_model_id}
     try:
         result = session.load_model(model_class, {'project_file': project_file})
     except (FileNotFoundError, ParameterExpectedError):
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index 6e1bb097..79cbe3b4 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -171,20 +171,21 @@ class TestIlastikOverApi(TestServerBaseClass):
         self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx))
 
         # load models with these paths
-        requests.put(
+        resp1 = requests.put(
             self.uri + 'ilastik/px/load/',
             params={
                 'project_file': ilp_win,
                 'duplicate': False,
             },
         )
-        requests.put(
+        resp2 = requests.put(
             self.uri + 'ilastik/px/load/',
             params={
                 'project_file': ilp_posx,
                 'duplicate': False,
             },
         )
+        self.assertEqual(resp1.json(), resp2.json())
 
         # assert that only one copy of the model is loaded
         resp_list_2 = requests.get(self.uri + 'models').json()
diff --git a/model_server/session.py b/model_server/session.py
index f9ef3d42..061711a5 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -127,7 +127,7 @@ class Session(object):
             for k in self.models.keys()
         }
 
-    def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict:
+    def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> str:
         """
         Returns model_id of first model where key and value match with .params field, or None
         :param is_path: uses platform-independent path comparison if True
-- 
GitLab