diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py index 20a1ef24235c0c4ee3d8cc101ffba9e92d86560b..ea8527985593b348749f8be7694196fc2dcb3199 100644 --- a/model_server/extensions/ilastik/router.py +++ b/model_server/extensions/ilastik/router.py @@ -14,42 +14,45 @@ import model_server.extensions.ilastik.pipelines.px_then_ob router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router) @router.put('/seg/load/') -def load_px_model(p: ilm.IlastikPixelClassifierParams) -> dict: +def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict: """ Load an ilastik pixel classifier model from its project file """ return load_ilastik_model( ilm.IlastikPixelClassifierModel, p, + model_id=model_id, ) @router.put('/pxmap_to_obj/load/') -def load_pxmap_to_obj_model(p: ilm.IlastikParams) -> dict: +def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: """ Load an ilastik object classifier from pixel predictions model from its project file """ return load_ilastik_model( ilm.IlastikObjectClassifierFromPixelPredictionsModel, p, + model_id=model_id, ) @router.put('/seg_to_obj/load/') -def load_seg_to_obj_model(p: ilm.IlastikParams) -> dict: +def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict: """ Load an ilastik object classifier from segmentation model from its project file """ return load_ilastik_model( ilm.IlastikObjectClassifierFromSegmentationModel, p, + model_id=model_id, ) -def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams) -> dict: +def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams, model_id=None) -> dict: project_file = p.project_file if not p.duplicate: existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True) if existing_model_id is not None: session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate') return {'model_id': existing_model_id} - result = session.load_model(model_class, key=p.model_id, params=p) + result = session.load_model(model_class, key=model_id, params=p) session.log_info(f'Loaded ilastik model {result} from {project_file}') return {'model_id': result} \ No newline at end of file diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index a78875f26781f5533b76454edd8636b6754b9c35..4e45dfb73310b1c35d3d674f02215271251be0af 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -217,17 +217,13 @@ class TestIlastikOverApi(TestServerTestCase): def test_load_ilastik_pixel_model(self): - resp_load = self.assertPutSuccess( + mid = self.assertPutSuccess( 'ilastik/seg/load/', body={'project_file': str(ilastik_classifiers['px']['path'])}, - ) - self.assertEqual(resp_load.status_code, 200, resp_load.json()) - model_id = resp_load.json()['model_id'] - resp_list = self.assertGetSuccess('models') - self.assertEqual(resp_list.status_code, 200) - rj = resp_list.json() - self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel') - return model_id + )['model_id'] + rl = self.assertGetSuccess('models') + self.assertEqual(rl[mid]['class'], 'IlastikPixelClassifierModel') + return mid def test_load_another_ilastik_pixel_model(self): model_id = self.test_load_ilastik_pixel_model()