diff --git a/.gitignore b/.gitignore index 7efb7f0b52ee1bf373b2783ec2f79985d28f967e..36a2805f7974aed48458a8127918b60d0e03f951 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,2 @@ */.idea/* *__pycache__* -/clients/imagej/.idea/workspace.xml -/clients/imagej/.idea/ diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 65057cca4017b2759b5c0d7b823a6501647faae2..1413456f633d188ad646376605749861ba904066 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -68,7 +68,7 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): ] pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert(len(pxmaps) == 1, 'ilastik generated more than one pixel map') + assert len(pxmaps) == 1, 'ilastik generated more than one pixel map' yxcz = np.moveaxis( pxmaps[0], diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index ba01d1b90729a9b673e4a8daa5ac0523ffaa80ab..0503ebaaad89382ccfb5bd747ea9372ccccf728b 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -20,25 +20,20 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: :param model_class: :param project_file: (*.ilp) ilastik project filename :param duplicate: load another instance of the same project file if True; return existing one if false - :return: dictionary with single key describing model's ID + :return: dict containing model's ID """ if not duplicate: - existing_model = session.find_param_in_loaded_models('project_file', project_file) + existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model is not None: return existing_model try: - result = { - 'model_id': session.load_model( - model_class, - {'project_file': project_file} - ) - } + result = session.load_model(model_class, {'project_file': project_file}) except (FileNotFoundError, ParameterExpectedError): raise HTTPException( status_code=404, detail=f'Could not load project file {project_file}', ) - return result + return {'model_id': result} @router.put('/px/load/') def load_px_model(project_file: str, duplicate: bool = True) -> dict: diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index d2498af209087113c34ffa508f5975dcfedc00cb..6e1bb0971d7aa579b9da7fe64d8627e3fe938419 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -1,3 +1,4 @@ +import pathlib import requests import unittest @@ -123,9 +124,8 @@ class TestIlastikOverApi(TestServerBaseClass): self.uri + 'ilastik/px/load/', params={'project_file': str(ilastik_classifiers['px'])}, ) - model_id = resp_load.json()['model_id'] - self.assertEqual(resp_load.status_code, 200, resp_load.json()) + model_id = resp_load.json()['model_id'] resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() @@ -155,6 +155,42 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_3rd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) + def test_no_duplicate_model_with_different_path_formats(self): + requests.get(self.uri + 'restart') + resp_list_1 = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_1), 0) + ilp = ilastik_classifiers['px'] + + # create and validate two copies of the same pathname with different string formats + ilp_win = str(pathlib.PureWindowsPath(ilp)) + self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes + self.assertEqual(ilp_win.count('/'), 0) + ilp_posx = ilastik_classifiers['px'].as_posix() + self.assertGreater(ilp_posx.count('/'), 0) + self.assertEqual(ilp_posx.count('\\'), 0) + self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx)) + + # load models with these paths + requests.put( + self.uri + 'ilastik/px/load/', + params={ + 'project_file': ilp_win, + 'duplicate': False, + }, + ) + requests.put( + self.uri + 'ilastik/px/load/', + params={ + 'project_file': ilp_posx, + 'duplicate': False, + }, + ) + + # assert that only one copy of the model is loaded + resp_list_2 = requests.get(self.uri + 'models').json() + print(resp_list_2) + self.assertEqual(len(resp_list_2), 1) + def test_load_ilastik_pxmap_to_obj_model(self): resp_load = requests.put( diff --git a/model_server/session.py b/model_server/session.py index 58e802c2247a5914c2e3dd4240e7222a4f56e63f..f9ef3d42c9085cfba95e38a26ef2b11b64957a1f 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -98,7 +98,7 @@ class Session(object): Load an instance of a given model class and attach to this session's model registry :param ModelClass: subclass of Model :param params: optional parameters that are passed to the model's construct - :return: dictionary that describes all currently loaded models + :return: model_id of loaded model """ mi = ModelClass(params=params) assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' @@ -127,14 +127,20 @@ class Session(object): for k in self.models.keys() } - def find_param_in_loaded_models(self, key: str, value: str) -> dict: + def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict: """ - Returns first instance of loaded model where key and value match with .params field, or None + Returns model_id of first model where key and value match with .params field, or None + :param is_path: uses platform-independent path comparison if True """ + models = self.describe_loaded_models() for mid, det in models.items(): - if det.get('params').get(key) == value: - return {mid: det} + if is_path: + if Path(det.get('params').get(key)) == Path(value): + return mid + else: + if det.get('params').get(key) == value: + return mid return None def restart(self, **kwargs): diff --git a/tests/test_api.py b/tests/test_api.py index 5a307376886a7d8325964ee0077739e22c96966a..e6c0f7997ea1a6396d30530d964f40cb7213fbe6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -20,6 +20,7 @@ class TestServerBaseClass(unittest.TestCase): ) self.uri = f'http://{host}:{port}/' self.server_process.start() + requests.get(self.uri + 'restart') def copy_input_file_to_server(self): from shutil import copyfile diff --git a/tests/test_session.py b/tests/test_session.py index bdad0486c4131af46e13d2edd23c54002c28957b..e3c9319205e183e1a3e1d9335a759e4bab42e3a9 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,3 +1,4 @@ +import pathlib import unittest from model_server.models import DummyImageToImageModel from model_server.session import Session @@ -96,9 +97,19 @@ class TestGetSessionObject(unittest.TestCase): # load a second model and confirm that the first is locatable by its param entry p2 = {'p2': 'def'} sesh.load_model(MC, params=p2) - find_kv = sesh.find_param_in_loaded_models('p1', 'abc') - self.assertEqual(len(find_kv), 1) - self.assertEqual(find_kv[mid]['params'], p1) + find_mid = sesh.find_param_in_loaded_models('p1', 'abc') + self.assertEqual(mid, find_mid) + self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1) + + def test_session_finds_existing_model_with_different_path_formats(self): + sesh = Session() + MC = DummyImageToImageModel + p1 = {'path': 'c:\\windows\\dummy.pa'} + p2 = {'path': 'c:/windows/dummy.pa'} + mid = sesh.load_model(MC, params=p1) + assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) + find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) + self.assertEqual(mid, find_mid) def test_change_output_path(self): import pathlib