From 1ab12cb8e024b193dadfe04c4ab4311a2a2c29f4 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 22 May 2024 16:07:36 +0200
Subject: [PATCH] ilastik model loading now takes duplicate in request body

---
 model_server/clients/ilastik_map_objects.py         |  2 +-
 model_server/extensions/ilastik/models.py           |  1 +
 model_server/extensions/ilastik/router.py           | 13 +++++--------
 .../extensions/ilastik/tests/test_ilastik.py        | 12 ++++--------
 4 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/model_server/clients/ilastik_map_objects.py b/model_server/clients/ilastik_map_objects.py
index c7febda3..ee16cbad 100644
--- a/model_server/clients/ilastik_map_objects.py
+++ b/model_server/clients/ilastik_map_objects.py
@@ -67,7 +67,7 @@ def main(request_func, in_abspath, params):
 	resp = request_func(
 		'PUT',
 		'/ilastik/pixel_then_object_classification/infer',
-		body={
+		{
 			'px_model_id': id_px_mod,
 			'ob_model_id': id_ob_mod,
 			'input_filename': in_file,
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 4a1c4c0a..e8afb238 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -15,6 +15,7 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat
 
 class IlastikParams(BaseModel):
     project_file: str
+    duplicate: bool = True
 
 class IlastikModel(Model):
 
diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py
index 9ad2733e..491e7e40 100644
--- a/model_server/extensions/ilastik/router.py
+++ b/model_server/extensions/ilastik/router.py
@@ -15,7 +15,7 @@ router = APIRouter(
 session = Session()
 
 
-def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=True) -> dict:
+def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
     """
     Load an ilastik model of a given class and project filename.
     :param model_class:
@@ -24,7 +24,7 @@ def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=Tr
     :return: dict containing model's ID
     """
     project_file = params.project_file
-    if not duplicate:
+    if not params.duplicate:
         existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
         if existing_model_id is not None:
             session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate')
@@ -34,27 +34,24 @@ def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=Tr
     return {'model_id': result}
 
 @router.put('/seg/load/')
-def load_px_model(params: ilm.IlastikPixelClassifierParams, duplicate: bool = True) -> dict:
+def load_px_model(params: ilm.IlastikPixelClassifierParams) -> dict:
     return load_ilastik_model(
         ilm.IlastikPixelClassifierModel,
         params,
-        duplicate=duplicate
     )
 
 @router.put('/pxmap_to_obj/load/')
-def load_pxmap_to_obj_model(params: ilm.IlastikParams, duplicate: bool = True) -> dict:
+def load_pxmap_to_obj_model(params: ilm.IlastikParams) -> dict:
     return load_ilastik_model(
         ilm.IlastikObjectClassifierFromPixelPredictionsModel,
         params,
-        duplicate=duplicate
     )
 
 @router.put('/seg_to_obj/load/')
-def load_seg_to_obj_model(params: ilm.IlastikParams, duplicate: bool = True) -> dict:
+def load_seg_to_obj_model(params: ilm.IlastikParams) -> dict:
     return load_ilastik_model(
         ilm.IlastikObjectClassifierFromSegmentationModel,
         params,
-        duplicate=duplicate
     )
 
 @router.put('/pixel_then_object_classification/infer')
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index a48867c0..1907459d 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -226,15 +226,13 @@ class TestIlastikOverApi(TestServerBaseClass):
         self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
         resp_load_2nd = self._put(
             'ilastik/seg/load/',
-            body={'project_file': str(ilastik_classifiers['px'])},
-            query={'duplicate': True},
+            body={'project_file': str(ilastik_classifiers['px']), 'duplicate': True},
         )
         resp_list_2nd = self._get('models').json()
         self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
         resp_load_3rd = self._put(
             'ilastik/seg/load/',
-            body={'project_file': str(ilastik_classifiers['px'])},
-            query={'duplicate': False},
+            body={'project_file': str(ilastik_classifiers['px']), 'duplicate': False},
         )
         resp_list_3rd = self._get('models').json()
         self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
@@ -273,13 +271,11 @@ class TestIlastikOverApi(TestServerBaseClass):
         # load models with these paths
         resp1 = self._put(
             'ilastik/seg/load/',
-            body={'project_file': ilp_win},
-            query={'duplicate': False},
+            body={'project_file': ilp_win, 'duplicate': False},
         )
         resp2 = self._put(
             'ilastik/seg/load/',
-            body={'project_file': ilp_posx},
-            query={'duplicate': False},
+            body={'project_file': ilp_posx, 'duplicate': False},
         )
         self.assertEqual(resp1.json(), resp2.json())
 
-- 
GitLab