From aaa18478b03dfe2640205c88131d1b449af4248c Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 3 Nov 2023 11:31:07 +0100 Subject: [PATCH] Model-loading always returns {'model_id': _} when successfuly --- extensions/ilastik/router.py | 6 +++--- extensions/ilastik/tests/test_ilastik.py | 5 +++-- model_server/session.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py index 0503ebaa..14ec3258 100644 --- a/extensions/ilastik/router.py +++ b/extensions/ilastik/router.py @@ -23,9 +23,9 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: :return: dict containing model's ID """ if not duplicate: - existing_model = session.find_param_in_loaded_models('project_file', project_file, is_path=True) - if existing_model is not None: - return existing_model + existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) + if existing_model_id is not None: + return {'model_id': existing_model_id} try: result = session.load_model(model_class, {'project_file': project_file}) except (FileNotFoundError, ParameterExpectedError): diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 6e1bb097..79cbe3b4 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -171,20 +171,21 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx)) # load models with these paths - requests.put( + resp1 = requests.put( self.uri + 'ilastik/px/load/', params={ 'project_file': ilp_win, 'duplicate': False, }, ) - requests.put( + resp2 = requests.put( self.uri + 'ilastik/px/load/', params={ 'project_file': ilp_posx, 'duplicate': False, }, ) + self.assertEqual(resp1.json(), resp2.json()) # assert that only one copy of the model is loaded resp_list_2 = requests.get(self.uri + 'models').json() diff --git a/model_server/session.py b/model_server/session.py index f9ef3d42..061711a5 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -127,7 +127,7 @@ class Session(object): for k in self.models.keys() } - def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> dict: + def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> str: """ Returns model_id of first model where key and value match with .params field, or None :param is_path: uses platform-independent path comparison if True -- GitLab