diff --git a/api.py b/api.py index 317c3d2181cbed89faa383748f1bb6baae464283..abd27bbfe85291362ae5d9e1691d6d3393154b09 100644 --- a/api.py +++ b/api.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, HTTPException -from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel +from model_server.ilastik import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel from model_server.model import DummyImageToImageModel, ParameterExpectedError from model_server.session import Session from model_server.workflow import infer_image_to_image @@ -38,7 +38,18 @@ def list_active_models(): def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} -def load_ilastik_model(model_class, project_file): +def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: + """ + Load an ilastik model of a given class and project filename. + :param model_class: + :param project_file: (*.ilp) ilastik project filename + :param duplicate: load another instance of the same project file if True; return existing one if false + :return: dictionary with single key describing model's ID + """ + if not duplicate: + existing_model = session.find_param_in_loaded_models('project_file', project_file) + if existing_model is not None: + return existing_model try: result = { 'model_id': session.load_model( @@ -54,12 +65,12 @@ def load_ilastik_model(model_class, project_file): return result @app.put('/models/ilastik/pixel_classification/load/') -def load_ilastik_pixel_classification_model(project_file: str) -> dict: - return load_ilastik_model(IlastikPixelClassifierModel, project_file) +def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate) @app.put('/models/ilastik/object_classification/load/') -def load_ilastik_object_classification_model(project_file: str) -> dict: - return load_ilastik_model(IlastikObjectClassifierModel, project_file) +def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate) def validate_workflow_inputs(model_ids, inpaths): for mid in model_ids: diff --git a/model_server/session.py b/model_server/session.py index d53654f84b9e3c74f9fd621f5cab6836653bdea6..eb82d1b6f089ccf990b4ca8fadff93f893e88a20 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -125,6 +125,16 @@ class Session(object): for k in self.models.keys() } + def find_param_in_loaded_models(self, key: str, value: str) -> dict: + """ + Returns first instance of loaded model where key and value match with .params field, or None + """ + models = self.describe_loaded_models() + for mid, det in models.items(): + if det.get('params').get(key) == value: + return {mid: det} + return None + def restart(self, **kwargs): self.__init__(**kwargs) diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 4af941d12e9f8f01d489d94093cb8503e273b67d..747de17a76d1cd487dd3c304f002b1d4a1029b83 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -132,6 +132,29 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') return model_id + def test_load_another_ilastik_pixel_model(self): + model_id = self.test_load_ilastik_pixel_model() + resp_list_1st = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_1st), 1, resp_list_1st) + resp_load_2nd = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={ + 'project_file': str(conf.testing.ilastik['pixel_classifier']), + 'duplicate': True, + }, + ) + resp_list_2nd = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) + resp_load_3rd = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={ + 'project_file': str(conf.testing.ilastik['pixel_classifier']), + 'duplicate': False, + }, + ) + resp_list_3rd = requests.get(self.uri + 'models').json() + self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) + def test_load_ilastik_object_model(self): resp_load = requests.put( diff --git a/tests/test_session.py b/tests/test_session.py index 614ad5d107635ed19a6c651493389789d7eb2939..a03e1ad1b45a43ef5fe217935c5ed13a8431410c 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -79,8 +79,17 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model_with_params(self): sesh = Session() MC = DummyImageToImageModel - p = {'p1': 'abc'} - success = sesh.load_model(MC, params=p) + p1 = {'p1': 'abc'} + success = sesh.load_model(MC, params=p1) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() - self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p) + mid = MC.__name__ + '_00' + self.assertEqual(loaded_models[mid]['params'], p1) + + # load a second model and confirm that the first is locatable by its param entry + p2 = {'p2': 'def'} + sesh.load_model(MC, params=p2) + find_kv = sesh.find_param_in_loaded_models('p1', 'abc') + self.assertEqual(len(find_kv), 1) + self.assertEqual(find_kv[mid]['params'], p1) +