From e5661bce97e0c14d0265278e664fe059b6fb14b4 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 5 Sep 2023 15:07:12 +0200 Subject: [PATCH] Consolidated ilastik model load API endpoints; report out error when loading project file --- api.py | 31 ++++++++++++++++++------------- tests/test_ilastik.py | 9 +++++++++ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/api.py b/api.py index ace225b4..e28446ad 100644 --- a/api.py +++ b/api.py @@ -3,7 +3,7 @@ from typing import Dict from fastapi import FastAPI, HTTPException from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel -from model_server.model import DummyImageToImageModel +from model_server.model import DummyImageToImageModel, ParameterExpectedError from model_server.session import Session from model_server.workflow import infer_image_to_image @@ -39,23 +39,28 @@ def list_active_models(): def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} +def load_ilastik_model(model_class, project_file): + try: + result = { + 'model_id': session.load_model( + model_class, + {'project_file': project_file} + ) + } + except (FileNotFoundError, ParameterExpectedError): + raise HTTPException( + status_code=404, + detail=f'Could not load project file {project_file}', + ) + return result + @app.put('/models/ilastik/pixel_classification/load/') def load_ilastik_pixel_classification_model(project_file: str) -> dict: - return { - 'model_id': session.load_model( - IlastikPixelClassifierModel, - {'project_file': project_file} - ) - } + return load_ilastik_model(IlastikPixelClassifierModel, project_file) @app.put('/models/ilastik/object_classification/load/') def load_ilastik_object_classification_model(project_file: str) -> dict: - return { - 'model_id': session.load_model( - IlastikObjectClassifierModel, - {'project_file': project_file} - ) - } + return load_ilastik_model(IlastikObjectClassifierModel, project_file) @app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 0ef8d0c4..1c55a9a6 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -109,6 +109,15 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertGreater(result.timer_results['inference'], 1.0) class TestIlastikOverApi(TestServerBaseClass): + + def test_httpexception_if_incorrect_project_file_loaded(self): + resp_load = requests.put( + self.uri + 'models/ilastik/pixel_classification/load/', + params={'project_file': 'improper.ilp'}, + ) + self.assertEqual(resp_load.status_code, 404) + + def test_load_ilastik_pixel_model(self): resp_load = requests.put( self.uri + 'models/ilastik/pixel_classification/load/', -- GitLab