diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index ba01d1b90729a9b673e4a8daa5ac0523ffaa80ab..0503ebaaad89382ccfb5bd747ea9372ccccf728b 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -20,25 +20,20 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: :param model_class: :param project_file: (*.ilp) ilastik project filename :param duplicate: load another instance of the same project file if True; return existing one if false - :return: dictionary with single key describing model's ID + :return: dict containing model's ID """ if not duplicate: - existing_model = session.find_param_in_loaded_models('project_file', project_file) + existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model is not None: return existing_model try: - result = { - 'model_id': session.load_model( - model_class, - {'project_file': project_file} - ) - } + result = session.load_model(model_class, {'project_file': project_file}) except (FileNotFoundError, ParameterExpectedError): raise HTTPException( status_code=404, detail=f'Could not load project file {project_file}', ) - return result + return {'model_id': result} @router.put('/px/load/') def load_px_model(project_file: str, duplicate: bool = True) -> dict: diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 5eb4076b1e58c1d7d82de8b4179183b4f054a0ea..6e1bb0971d7aa579b9da7fe64d8627e3fe938419 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -124,9 +124,8 @@ class TestIlastikOverApi(TestServerBaseClass): self.uri + 'ilastik/px/load/', params={'project_file': str(ilastik_classifiers['px'])}, ) - model_id = resp_load.json()['model_id'] - self.assertEqual(resp_load.status_code, 200, resp_load.json()) + model_id = resp_load.json()['model_id'] resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() @@ -157,30 +156,37 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) def test_no_duplicate_model_with_different_path_formats(self): - resp_restart = requests.get(self.uri + 'restart') + requests.get(self.uri + 'restart') resp_list_1 = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_1), 0) ilp = ilastik_classifiers['px'] + + # create and validate two copies of the same pathname with different string formats ilp_win = str(pathlib.PureWindowsPath(ilp)) self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes self.assertEqual(ilp_win.count('/'), 0) ilp_posx = ilastik_classifiers['px'].as_posix() self.assertGreater(ilp_posx.count('/'), 0) self.assertEqual(ilp_posx.count('\\'), 0) - resp_load_1 = requests.put( + self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx)) + + # load models with these paths + requests.put( self.uri + 'ilastik/px/load/', params={ 'project_file': ilp_win, 'duplicate': False, }, ) - resp_load_2 = requests.put( + requests.put( self.uri + 'ilastik/px/load/', params={ 'project_file': ilp_posx, 'duplicate': False, }, ) + + # assert that only one copy of the model is loaded resp_list_2 = requests.get(self.uri + 'models').json() print(resp_list_2) self.assertEqual(len(resp_list_2), 1)