Skip to content
Snippets Groups Projects
Commit c8c58fcd authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Validated non-duplication when loading ilastik model with same project file...

Validated non-duplication when loading ilastik model with same project file name, irrespective of path formatting
parent 0a2ca539
No related branches found
No related tags found
No related merge requests found
*/.idea/*
*__pycache__*
/clients/imagej/.idea/workspace.xml
/clients/imagej/.idea/
......@@ -68,7 +68,7 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
]
pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert(len(pxmaps) == 1, 'ilastik generated more than one pixel map')
assert len(pxmaps) == 1, 'ilastik generated more than one pixel map'
yxcz = np.moveaxis(
pxmaps[0],
......
......@@ -20,25 +20,20 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file:
:param model_class:
:param project_file: (*.ilp) ilastik project filename
:param duplicate: load another instance of the same project file if True; return existing one if false
:return: dictionary with single key describing model's ID
:return: dict containing model's ID
"""
if not duplicate:
existing_model = session.find_param_in_loaded_models('project_file', project_file)
existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
if existing_model is not None:
return existing_model
try:
result = {
'model_id': session.load_model(
model_class,
{'project_file': project_file}
)
}
result = session.load_model(model_class, {'project_file': project_file})
except (FileNotFoundError, ParameterExpectedError):
raise HTTPException(
status_code=404,
detail=f'Could not load project file {project_file}',
)
return result
return {'model_id': result}
@router.put('/px/load/')
def load_px_model(project_file: str, duplicate: bool = True) -> dict:
......
import pathlib
import requests
import unittest
......@@ -123,9 +124,8 @@ class TestIlastikOverApi(TestServerBaseClass):
self.uri + 'ilastik/px/load/',
params={'project_file': str(ilastik_classifiers['px'])},
)
model_id = resp_load.json()['model_id']
self.assertEqual(resp_load.status_code, 200, resp_load.json())
model_id = resp_load.json()['model_id']
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
......@@ -155,6 +155,42 @@ class TestIlastikOverApi(TestServerBaseClass):
resp_list_3rd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_no_duplicate_model_with_different_path_formats(self):
requests.get(self.uri + 'restart')
resp_list_1 = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1), 0)
ilp = ilastik_classifiers['px']
# create and validate two copies of the same pathname with different string formats
ilp_win = str(pathlib.PureWindowsPath(ilp))
self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes
self.assertEqual(ilp_win.count('/'), 0)
ilp_posx = ilastik_classifiers['px'].as_posix()
self.assertGreater(ilp_posx.count('/'), 0)
self.assertEqual(ilp_posx.count('\\'), 0)
self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx))
# load models with these paths
requests.put(
self.uri + 'ilastik/px/load/',
params={
'project_file': ilp_win,
'duplicate': False,
},
)
requests.put(
self.uri + 'ilastik/px/load/',
params={
'project_file': ilp_posx,
'duplicate': False,
},
)
# assert that only one copy of the model is loaded
resp_list_2 = requests.get(self.uri + 'models').json()
print(resp_list_2)
self.assertEqual(len(resp_list_2), 1)
def test_load_ilastik_pxmap_to_obj_model(self):
resp_load = requests.put(
......
......@@ -98,7 +98,7 @@ class Session(object):
Load an instance of a given model class and attach to this session's model registry
:param ModelClass: subclass of Model
:param params: optional parameters that are passed to the model's construct
:return: dictionary that describes all currently loaded models
:return: model_id of loaded model
"""
mi = ModelClass(params=params)
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
......@@ -127,14 +127,20 @@ class Session(object):
for k in self.models.keys()
}
def find_param_in_loaded_models(self, key: str, value: str) -> dict:
def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict:
"""
Returns first instance of loaded model where key and value match with .params field, or None
Returns model_id of first model where key and value match with .params field, or None
:param is_path: uses platform-independent path comparison if True
"""
models = self.describe_loaded_models()
for mid, det in models.items():
if det.get('params').get(key) == value:
return {mid: det}
if is_path:
if Path(det.get('params').get(key)) == Path(value):
return mid
else:
if det.get('params').get(key) == value:
return mid
return None
def restart(self, **kwargs):
......
......@@ -20,6 +20,7 @@ class TestServerBaseClass(unittest.TestCase):
)
self.uri = f'http://{host}:{port}/'
self.server_process.start()
requests.get(self.uri + 'restart')
def copy_input_file_to_server(self):
from shutil import copyfile
......
import pathlib
import unittest
from model_server.models import DummyImageToImageModel
from model_server.session import Session
......@@ -96,9 +97,19 @@ class TestGetSessionObject(unittest.TestCase):
# load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'}
sesh.load_model(MC, params=p2)
find_kv = sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(len(find_kv), 1)
self.assertEqual(find_kv[mid]['params'], p1)
find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(mid, find_mid)
self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self):
sesh = Session()
MC = DummyImageToImageModel
p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'}
mid = sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
self.assertEqual(mid, find_mid)
def test_change_output_path(self):
import pathlib
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment