From 3daaf0baf042d90a339b5d882c607f7ca6448583 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 3 Nov 2023 10:31:41 +0100
Subject: [PATCH] Test covers problem with models being duplicated because of
 different path specifications; not yet fixed at API level

---
 extensions/ilastik/tests/test_ilastik.py | 30 ++++++++++++++++++++++++
 model_server/session.py                  | 12 +++++++---
 tests/test_session.py                    | 11 +++++++++
 3 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index d2498af2..5eb4076b 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -1,3 +1,4 @@
+import pathlib
 import requests
 import unittest
 
@@ -155,6 +156,35 @@ class TestIlastikOverApi(TestServerBaseClass):
         resp_list_3rd = requests.get(self.uri + 'models').json()
         self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
 
+    def test_no_duplicate_model_with_different_path_formats(self):
+        resp_restart = requests.get(self.uri + 'restart')
+        resp_list_1 = requests.get(self.uri + 'models').json()
+        self.assertEqual(len(resp_list_1), 0)
+        ilp = ilastik_classifiers['px']
+        ilp_win = str(pathlib.PureWindowsPath(ilp))
+        self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes
+        self.assertEqual(ilp_win.count('/'), 0)
+        ilp_posx = ilastik_classifiers['px'].as_posix()
+        self.assertGreater(ilp_posx.count('/'), 0)
+        self.assertEqual(ilp_posx.count('\\'), 0)
+        resp_load_1 = requests.put(
+            self.uri + 'ilastik/px/load/',
+            params={
+                'project_file': ilp_win,
+                'duplicate': False,
+            },
+        )
+        resp_load_2 = requests.put(
+            self.uri + 'ilastik/px/load/',
+            params={
+                'project_file': ilp_posx,
+                'duplicate': False,
+            },
+        )
+        resp_list_2 = requests.get(self.uri + 'models').json()
+        print(resp_list_2)
+        self.assertEqual(len(resp_list_2), 1)
+
 
     def test_load_ilastik_pxmap_to_obj_model(self):
         resp_load = requests.put(
diff --git a/model_server/session.py b/model_server/session.py
index 73b000cf..c8996b39 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -127,14 +127,20 @@ class Session(object):
             for k in self.models.keys()
         }
 
-    def find_param_in_loaded_models(self, key: str, value: str) -> dict:
+    def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict:
         """
         Returns first instance of loaded model where key and value match with .params field, or None
+        :param is_path: uses platform-independent path comparison if True
         """
+
         models = self.describe_loaded_models()
         for mid, det in models.items():
-            if det.get('params').get(key) == value:
-                return {mid: det}
+            if is_path:
+                if Path(det.get('params').get(key)) == Path(value):
+                    return {mid: det}
+            else:
+                if det.get('params').get(key) == value:
+                    return {mid: det}
         return None
 
     def restart(self, **kwargs):
diff --git a/tests/test_session.py b/tests/test_session.py
index bdad0486..555e020b 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,3 +1,4 @@
+import pathlib
 import unittest
 from model_server.models import DummyImageToImageModel
 from model_server.session import Session
@@ -100,6 +101,16 @@ class TestGetSessionObject(unittest.TestCase):
         self.assertEqual(len(find_kv), 1)
         self.assertEqual(find_kv[mid]['params'], p1)
 
+    def test_session_finds_existing_model_with_different_path_formats(self):
+        sesh = Session()
+        MC = DummyImageToImageModel
+        p1 = {'path': 'c:\\windows\\dummy.pa'}
+        p2 = {'path': 'c:/windows/dummy.pa'}
+        sesh.load_model(MC, params=p1)
+        assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
+        find_kv = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
+        self.assertEqual(len(find_kv), 1)
+
     def test_change_output_path(self):
         import pathlib
         sesh = Session()
-- 
GitLab