From 3daaf0baf042d90a339b5d882c607f7ca6448583 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 3 Nov 2023 10:31:41 +0100 Subject: [PATCH] Test covers problem with models being duplicated because of different path specifications; not yet fixed at API level --- extensions/ilastik/tests/test_ilastik.py | 30 ++++++++++++++++++++++++ model_server/session.py | 12 +++++++--- tests/test_session.py | 11 +++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index d2498af2..5eb4076b 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -1,3 +1,4 @@ +import pathlib import requests import unittest @@ -155,6 +156,35 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_3rd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) + def test_no_duplicate_model_with_different_path_formats(self): + resp_restart = requests.get(self.uri + 'restart') + resp_list_1 = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_1), 0) + ilp = ilastik_classifiers['px'] + ilp_win = str(pathlib.PureWindowsPath(ilp)) + self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes + self.assertEqual(ilp_win.count('/'), 0) + ilp_posx = ilastik_classifiers['px'].as_posix() + self.assertGreater(ilp_posx.count('/'), 0) + self.assertEqual(ilp_posx.count('\\'), 0) + resp_load_1 = requests.put( + self.uri + 'ilastik/px/load/', + params={ + 'project_file': ilp_win, + 'duplicate': False, + }, + ) + resp_load_2 = requests.put( + self.uri + 'ilastik/px/load/', + params={ + 'project_file': ilp_posx, + 'duplicate': False, + }, + ) + resp_list_2 = requests.get(self.uri + 'models').json() + print(resp_list_2) + self.assertEqual(len(resp_list_2), 1) + def test_load_ilastik_pxmap_to_obj_model(self): resp_load = requests.put( diff --git a/model_server/session.py b/model_server/session.py index 73b000cf..c8996b39 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -127,14 +127,20 @@ class Session(object): for k in self.models.keys() } - def find_param_in_loaded_models(self, key: str, value: str) -> dict: + def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict: """ Returns first instance of loaded model where key and value match with .params field, or None + :param is_path: uses platform-independent path comparison if True """ + models = self.describe_loaded_models() for mid, det in models.items(): - if det.get('params').get(key) == value: - return {mid: det} + if is_path: + if Path(det.get('params').get(key)) == Path(value): + return {mid: det} + else: + if det.get('params').get(key) == value: + return {mid: det} return None def restart(self, **kwargs): diff --git a/tests/test_session.py b/tests/test_session.py index bdad0486..555e020b 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,3 +1,4 @@ +import pathlib import unittest from model_server.models import DummyImageToImageModel from model_server.session import Session @@ -100,6 +101,16 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(len(find_kv), 1) self.assertEqual(find_kv[mid]['params'], p1) + def test_session_finds_existing_model_with_different_path_formats(self): + sesh = Session() + MC = DummyImageToImageModel + p1 = {'path': 'c:\\windows\\dummy.pa'} + p2 = {'path': 'c:/windows/dummy.pa'} + sesh.load_model(MC, params=p1) + assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) + find_kv = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) + self.assertEqual(len(find_kv), 1) + def test_change_output_path(self): import pathlib sesh = Session() -- GitLab