From ba1bc5d16fe709c7345e453946b7361e620fac0b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 3 Nov 2023 11:21:22 +0100
Subject: [PATCH] Validated non-duplication when loading ilastik model with
 same project file name, irrespective of path formatting

---
 extensions/ilastik/router.py             | 13 ++++---------
 extensions/ilastik/tests/test_ilastik.py | 16 +++++++++++-----
 2 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py
index ba01d1b9..0503ebaa 100644
--- a/extensions/ilastik/router.py
+++ b/extensions/ilastik/router.py
@@ -20,25 +20,20 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file:
     :param model_class:
     :param project_file: (*.ilp) ilastik project filename
     :param duplicate: load another instance of the same project file if True; return existing one if false
-    :return: dictionary with single key describing model's ID
+    :return: dict containing model's ID
     """
     if not duplicate:
-        existing_model = session.find_param_in_loaded_models('project_file', project_file)
+        existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
         if existing_model is not None:
             return existing_model
     try:
-        result = {
-            'model_id': session.load_model(
-                model_class,
-                {'project_file': project_file}
-            )
-        }
+        result = session.load_model(model_class, {'project_file': project_file})
     except (FileNotFoundError, ParameterExpectedError):
         raise HTTPException(
             status_code=404,
             detail=f'Could not load project file {project_file}',
         )
-    return result
+    return {'model_id': result}
 
 @router.put('/px/load/')
 def load_px_model(project_file: str, duplicate: bool = True) -> dict:
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index 5eb4076b..6e1bb097 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -124,9 +124,8 @@ class TestIlastikOverApi(TestServerBaseClass):
             self.uri + 'ilastik/px/load/',
             params={'project_file': str(ilastik_classifiers['px'])},
         )
-        model_id = resp_load.json()['model_id']
-
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        model_id = resp_load.json()['model_id']
         resp_list = requests.get(self.uri + 'models')
         self.assertEqual(resp_list.status_code, 200)
         rj = resp_list.json()
@@ -157,30 +156,37 @@ class TestIlastikOverApi(TestServerBaseClass):
         self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
 
     def test_no_duplicate_model_with_different_path_formats(self):
-        resp_restart = requests.get(self.uri + 'restart')
+        requests.get(self.uri + 'restart')
         resp_list_1 = requests.get(self.uri + 'models').json()
         self.assertEqual(len(resp_list_1), 0)
         ilp = ilastik_classifiers['px']
+
+        # create and validate two copies of the same pathname with different string formats
         ilp_win = str(pathlib.PureWindowsPath(ilp))
         self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes
         self.assertEqual(ilp_win.count('/'), 0)
         ilp_posx = ilastik_classifiers['px'].as_posix()
         self.assertGreater(ilp_posx.count('/'), 0)
         self.assertEqual(ilp_posx.count('\\'), 0)
-        resp_load_1 = requests.put(
+        self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx))
+
+        # load models with these paths
+        requests.put(
             self.uri + 'ilastik/px/load/',
             params={
                 'project_file': ilp_win,
                 'duplicate': False,
             },
         )
-        resp_load_2 = requests.put(
+        requests.put(
             self.uri + 'ilastik/px/load/',
             params={
                 'project_file': ilp_posx,
                 'duplicate': False,
             },
         )
+
+        # assert that only one copy of the model is loaded
         resp_list_2 = requests.get(self.uri + 'models').json()
         print(resp_list_2)
         self.assertEqual(len(resp_list_2), 1)
-- 
GitLab