Skip to content
Snippets Groups Projects
Commit 1ab12cb8 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

ilastik model loading now takes duplicate in request body

parent 621bb0cf
No related branches found
No related tags found
2 merge requests!50Release 2024.06.03,!45Issue0037
......@@ -67,7 +67,7 @@ def main(request_func, in_abspath, params):
resp = request_func(
'PUT',
'/ilastik/pixel_then_object_classification/infer',
body={
{
'px_model_id': id_px_mod,
'ob_model_id': id_ob_mod,
'input_filename': in_file,
......
......@@ -15,6 +15,7 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat
class IlastikParams(BaseModel):
project_file: str
duplicate: bool = True
class IlastikModel(Model):
......
......@@ -15,7 +15,7 @@ router = APIRouter(
session = Session()
def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=True) -> dict:
def load_ilastik_model(model_class: ilm.IlastikModel, params: ilm.IlastikParams) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
......@@ -24,7 +24,7 @@ def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=Tr
:return: dict containing model's ID
"""
project_file = params.project_file
if not duplicate:
if not params.duplicate:
existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
if existing_model_id is not None:
session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate')
......@@ -34,27 +34,24 @@ def load_ilastik_model(model_class: ilm.IlastikModel, params: dict, duplicate=Tr
return {'model_id': result}
@router.put('/seg/load/')
def load_px_model(params: ilm.IlastikPixelClassifierParams, duplicate: bool = True) -> dict:
def load_px_model(params: ilm.IlastikPixelClassifierParams) -> dict:
return load_ilastik_model(
ilm.IlastikPixelClassifierModel,
params,
duplicate=duplicate
)
@router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(params: ilm.IlastikParams, duplicate: bool = True) -> dict:
def load_pxmap_to_obj_model(params: ilm.IlastikParams) -> dict:
return load_ilastik_model(
ilm.IlastikObjectClassifierFromPixelPredictionsModel,
params,
duplicate=duplicate
)
@router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(params: ilm.IlastikParams, duplicate: bool = True) -> dict:
def load_seg_to_obj_model(params: ilm.IlastikParams) -> dict:
return load_ilastik_model(
ilm.IlastikObjectClassifierFromSegmentationModel,
params,
duplicate=duplicate
)
@router.put('/pixel_then_object_classification/infer')
......
......@@ -226,15 +226,13 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = self._put(
'ilastik/seg/load/',
body={'project_file': str(ilastik_classifiers['px'])},
query={'duplicate': True},
body={'project_file': str(ilastik_classifiers['px']), 'duplicate': True},
)
resp_list_2nd = self._get('models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = self._put(
'ilastik/seg/load/',
body={'project_file': str(ilastik_classifiers['px'])},
query={'duplicate': False},
body={'project_file': str(ilastik_classifiers['px']), 'duplicate': False},
)
resp_list_3rd = self._get('models').json()
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
......@@ -273,13 +271,11 @@ class TestIlastikOverApi(TestServerBaseClass):
# load models with these paths
resp1 = self._put(
'ilastik/seg/load/',
body={'project_file': ilp_win},
query={'duplicate': False},
body={'project_file': ilp_win, 'duplicate': False},
)
resp2 = self._put(
'ilastik/seg/load/',
body={'project_file': ilp_posx},
query={'duplicate': False},
body={'project_file': ilp_posx, 'duplicate': False},
)
self.assertEqual(resp1.json(), resp2.json())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment