diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index 0503ebaaad89382ccfb5bd747ea9372ccccf728b..14ec32587349f33476bc62f7e37ba6d31e49445d 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -23,9 +23,9 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: :return: dict containing model's ID """ if not duplicate: - existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True) - if existing_model is not None: - return existing_model + existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) + if existing_model_id is not None: + return {'model_id': existing_model_id} try: result = session.load_model(model_class, {'project_file': project_file}) except (FileNotFoundError, ParameterExpectedError): diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 6e1bb0971d7aa579b9da7fe64d8627e3fe938419..79cbe3b49630468c8c97558e22923cd522b26cf3 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -171,20 +171,21 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx)) # load models with these paths - requests.put( + resp1 = requests.put( self.uri + 'ilastik/px/load/', params={ 'project_file': ilp_win, 'duplicate': False, }, ) - requests.put( + resp2 = requests.put( self.uri + 'ilastik/px/load/', params={ 'project_file': ilp_posx, 'duplicate': False, }, ) + self.assertEqual(resp1.json(), resp2.json()) # assert that only one copy of the model is loaded resp_list_2 = requests.get(self.uri + 'models').json() diff --git a/model_server/session.py b/model_server/session.py index f9ef3d42c9085cfba95e38a26ef2b11b64957a1f..061711a5654a4b80d95209a422461aa406abc7b1 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -127,7 +127,7 @@ class Session(object): for k in self.models.keys() } - def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict: + def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> str: """ Returns model_id of first model where key and value match with .params field, or None :param is_path: uses platform-independent path comparison if True