diff --git a/api.py b/api.py index ace225b4df645a8bde6eb65d8acafe0adfcbc604..e28446ad27ce88ff56f286f8badfa0c42f51dce8 100644 --- a/api.py +++ b/api.py @@ -3,7 +3,7 @@ from typing import Dict from fastapi import FastAPI, HTTPException from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel -from model_server.model import DummyImageToImageModel +from model_server.model import DummyImageToImageModel, ParameterExpectedError from model_server.session import Session from model_server.workflow import infer_image_to_image @@ -39,23 +39,28 @@ def list_active_models(): def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} +def load_ilastik_model(model_class, project_file): + try: + result = { + 'model_id': session.load_model( + model_class, + {'project_file': project_file} + ) + } + except (FileNotFoundError, ParameterExpectedError): + raise HTTPException( + status_code=404, + detail=f'Could not load project file {project_file}', + ) + return result + @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(project_file: str) -> dict: - return { - 'model_id': session.load_model( - IlastikPixelClassifierModel, - {'project_file': project_file} - ) - } + return load_ilastik_model(IlastikPixelClassifierModel, project_file) @app.put('/models/ilastik/object_classification/load/') def load_ilastik_object_classification_model(project_file: str) -> dict: - return { - 'model_id': session.load_model( - IlastikObjectClassifierModel, - {'project_file': project_file} - ) - } + return load_ilastik_model(IlastikObjectClassifierModel, project_file) @app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 0ef8d0c4ef102ba7053131024b57d344b226fe3f..1c55a9a63f4dfd64900e1dd25fdad7800561dee0 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -109,6 +109,15 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertGreater(result.timer_results['inference'], 1.0) class TestIlastikOverApi(TestServerBaseClass): + + def test_httpexception_if_incorrect_project_file_loaded(self): + resp_load = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={'project_file': 'improper.ilp'}, + ) + self.assertEqual(resp_load.status_code, 404) + + def test_load_ilastik_pixel_model(self): resp_load = requests.put( self.uri + 'models/ilastik/pixel_classification/load/',