From fe4dda4f3edc7d60f3c65490b1ed8cb10cd0c48f Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 7 Sep 2023 16:38:36 +0200 Subject: [PATCH] Option to skip loading a model if its project file is already in a loaded one --- api.py | 23 +++++++++++++++++------ model_server/session.py | 10 ++++++++++ tests/test_ilastik.py | 23 +++++++++++++++++++++++ tests/test_session.py | 15 ++++++++++++--- 4 files changed, 62 insertions(+), 9 deletions(-) diff --git a/api.py b/api.py index 317c3d21..abd27bbf 100644 --- a/api.py +++ b/api.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, HTTPException -from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel +from model_server.ilastik import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel from model_server.model import DummyImageToImageModel, ParameterExpectedError from model_server.session import Session from model_server.workflow import infer_image_to_image @@ -38,7 +38,18 @@ def list_active_models(): def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} -def load_ilastik_model(model_class, project_file): +def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: + """ + Load an ilastik model of a given class and project filename. + :param model_class: + :param project_file: (*.ilp) ilastik project filename + :param duplicate: load another instance of the same project file if True; return existing one if false + :return: dictionary with single key describing model's ID + """ + if not duplicate: + existing_model = session.find_param_in_loaded_models('project_file', project_file) + if existing_model is not None: + return existing_model try: result = { 'model_id': session.load_model( @@ -54,12 +65,12 @@ def load_ilastik_model(model_class, project_file): return result @app.put('/models/ilastik/pixel_classification/load/') -def load_ilastik_pixel_classification_model(project_file: str) -> dict: - return load_ilastik_model(IlastikPixelClassifierModel, project_file) +def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate) @app.put('/models/ilastik/object_classification/load/') -def load_ilastik_object_classification_model(project_file: str) -> dict: - return load_ilastik_model(IlastikObjectClassifierModel, project_file) +def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate) def validate_workflow_inputs(model_ids, inpaths): for mid in model_ids: diff --git a/model_server/session.py b/model_server/session.py index d53654f8..eb82d1b6 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -125,6 +125,16 @@ class Session(object): for k in self.models.keys() } + def find_param_in_loaded_models(self, key: str, value: str) -> dict: + """ + Returns first instance of loaded model where key and value match with .params field, or None + """ + models = self.describe_loaded_models() + for mid, det in models.items(): + if det.get('params').get(key) == value: + return {mid: det} + return None + def restart(self, **kwargs): self.__init__(**kwargs) diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 4af941d1..747de17a 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -132,6 +132,29 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') return model_id + def test_load_another_ilastik_pixel_model(self): + model_id = self.test_load_ilastik_pixel_model() + resp_list_1st = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_1st), 1, resp_list_1st) + resp_load_2nd = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={ + 'project_file': str(conf.testing.ilastik['pixel_classifier']), + 'duplicate': True, + }, + ) + resp_list_2nd = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) + resp_load_3rd = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={ + 'project_file': str(conf.testing.ilastik['pixel_classifier']), + 'duplicate': False, + }, + ) + resp_list_3rd = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) + def test_load_ilastik_object_model(self): resp_load = requests.put( diff --git a/tests/test_session.py b/tests/test_session.py index 614ad5d1..a03e1ad1 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -79,8 +79,17 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model_with_params(self): sesh = Session() MC = DummyImageToImageModel - p = {'p1': 'abc'} - success = sesh.load_model(MC, params=p) + p1 = {'p1': 'abc'} + success = sesh.load_model(MC, params=p1) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() - self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p) + mid = MC.__name__ + '_00' + self.assertEqual(loaded_models[mid]['params'], p1) + + # load a second model and confirm that the first is locatable by its param entry + p2 = {'p2': 'def'} + sesh.load_model(MC, params=p2) + find_kv = sesh.find_param_in_loaded_models('p1', 'abc') + self.assertEqual(len(find_kv), 1) + self.assertEqual(find_kv[mid]['params'], p1) + -- GitLab