From 1ab12cb8e024b193dadfe04c4ab4311a2a2c29f4 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 22 May 2024 16:07:36 +0200 Subject: [PATCH] ilastik model loading now takes duplicate in request body --- model_server/clients/ilastik_map_objects.py | 2 +- model_server/extensions/ilastik/models.py | 1 + model_server/extensions/ilastik/router.py | 13 +++++-------- .../extensions/ilastik/tests/test_ilastik.py | 12 ++++-------- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/model_server/clients/ilastik_map_objects.py b/model_server/clients/ilastik_map_objects.py index c7febda3..ee16cbad 100644 --- a/model_server/clients/ilastik_map_objects.py +++ b/model_server/clients/ilastik_map_objects.py @@ -67,7 +67,7 @@ def main(request_func, in_abspath, params): resp = request_func( 'PUT', '/ilastik/pixel_then_object_classification/infer', - body={ + { 'px_model_id': id_px_mod, 'ob_model_id': id_ob_mod, 'input_filename': in_file, diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 4a1c4c0a..e8afb238 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -15,6 +15,7 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat class IlastikParams(BaseModel): project_file: str + duplicate: bool = True class IlastikModel(Model): diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 9ad2733e..491e7e40 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -15,7 +15,7 @@ router = APIRouter( session = Session() -def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=True) -> dict: +def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict: """ Load an ilastik model of a given class and project filename. :param model_class: @@ -24,7 +24,7 @@ def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=Tr :return: dict containing model's ID """ project_file = params.project_file - if not duplicate: + if not params.duplicate: existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model_id is not None: session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') @@ -34,27 +34,24 @@ def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=Tr return {'model_id': result} @router.put('/seg/load/') -def load_px_model(params: ilm.IlastikPixelClassifierParams, duplicate: bool = True) -> dict: +def load_px_model(params: ilm.IlastikPixelClassifierParams) -> dict: return load_ilastik_model( ilm.IlastikPixelClassifierModel, params, - duplicate=duplicate ) @router.put('/pxmap_to_obj/load/') -def load_pxmap_to_obj_model(params: ilm.IlastikParams, duplicate: bool = True) -> dict: +def load_pxmap_to_obj_model(params: ilm.IlastikParams) -> dict: return load_ilastik_model( ilm.IlastikObjectClassifierFromPixelPredictionsModel, params, - duplicate=duplicate ) @router.put('/seg_to_obj/load/') -def load_seg_to_obj_model(params: ilm.IlastikParams, duplicate: bool = True) -> dict: +def load_seg_to_obj_model(params: ilm.IlastikParams) -> dict: return load_ilastik_model( ilm.IlastikObjectClassifierFromSegmentationModel, params, - duplicate=duplicate ) @router.put('/pixel_then_object_classification/infer') diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index a48867c0..1907459d 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -226,15 +226,13 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = self._put( 'ilastik/seg/load/', - body={'project_file': str(ilastik_classifiers['px'])}, - query={'duplicate': True}, + body={'project_file': str(ilastik_classifiers['px']), 'duplicate': True}, ) resp_list_2nd = self._get('models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = self._put( 'ilastik/seg/load/', - body={'project_file': str(ilastik_classifiers['px'])}, - query={'duplicate': False}, + body={'project_file': str(ilastik_classifiers['px']), 'duplicate': False}, ) resp_list_3rd = self._get('models').json() self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd) @@ -273,13 +271,11 @@ class TestIlastikOverApi(TestServerBaseClass): # load models with these paths resp1 = self._put( 'ilastik/seg/load/', - body={'project_file': ilp_win}, - query={'duplicate': False}, + body={'project_file': ilp_win, 'duplicate': False}, ) resp2 = self._put( 'ilastik/seg/load/', - body={'project_file': ilp_posx}, - query={'duplicate': False}, + body={'project_file': ilp_posx, 'duplicate': False}, ) self.assertEqual(resp1.json(), resp2.json()) -- GitLab