Skip to content
Snippets Groups Projects
Commit 3daaf0ba authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covers problem with models being duplicated because of different path...

Test covers problem with models being duplicated because of different path specifications; not yet fixed at API level
parent a34e3db8
No related branches found
No related tags found
No related merge requests found
import pathlib
import requests
import unittest
......@@ -155,6 +156,35 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_3rd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_no_duplicate_model_with_different_path_formats(self):
resp_restart = requests.get(self.uri + 'restart')
resp_list_1 = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1), 0)
ilp = ilastik_classifiers['px']
ilp_win = str(pathlib.PureWindowsPath(ilp))
self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes
self.assertEqual(ilp_win.count('/'), 0)
ilp_posx = ilastik_classifiers['px'].as_posix()
self.assertGreater(ilp_posx.count('/'), 0)
self.assertEqual(ilp_posx.count('\\'), 0)
resp_load_1 = requests.put(
self.uri + 'ilastik/px/load/',
params={
'project_file': ilp_win,
'duplicate': False,
},
)
resp_load_2 = requests.put(
self.uri + 'ilastik/px/load/',
params={
'project_file': ilp_posx,
'duplicate': False,
},
)
resp_list_2 = requests.get(self.uri + 'models').json()
print(resp_list_2)
self.assertEqual(len(resp_list_2), 1)
def test_load_ilastik_pxmap_to_obj_model(self):
resp_load = requests.put(
......
......@@ -127,14 +127,20 @@ class Session(object):
for k in self.models.keys()
}
def find_param_in_loaded_models(self, key: str, value: str) -> dict:
def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict:
"""
Returns first instance of loaded model where key and value match with .params field, or None
:param is_path: uses platform-independent path comparison if True
"""
models = self.describe_loaded_models()
for mid, det in models.items():
if det.get('params').get(key) == value:
return {mid: det}
if is_path:
if Path(det.get('params').get(key)) == Path(value):
return {mid: det}
else:
if det.get('params').get(key) == value:
return {mid: det}
return None
def restart(self, **kwargs):
......
import pathlib
import unittest
from model_server.models import DummyImageToImageModel
from model_server.session import Session
......@@ -100,6 +101,16 @@ class TestGetSessionObject(unittest.TestCase):
self.assertEqual(len(find_kv), 1)
self.assertEqual(find_kv[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self):
sesh = Session()
MC = DummyImageToImageModel
p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'}
sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_kv = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
self.assertEqual(len(find_kv), 1)
def test_change_output_path(self):
import pathlib
sesh = Session()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment