Skip to content
Snippets Groups Projects
Commit e5661bce authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Consolidated ilastik model load API endpoints; report out error when loading project file

parent 2fd272e1
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@ from typing import Dict
from fastapi import FastAPI, HTTPException
from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.model import DummyImageToImageModel
from model_server.model import DummyImageToImageModel, ParameterExpectedError
from model_server.session import Session
from model_server.workflow import infer_image_to_image
......@@ -39,23 +39,28 @@ def list_active_models():
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyImageToImageModel)}
def load_ilastik_model(model_class, project_file):
try:
result = {
'model_id': session.load_model(
model_class,
{'project_file': project_file}
)
}
except (FileNotFoundError, ParameterExpectedError):
raise HTTPException(
status_code=404,
detail=f'Could not load project file {project_file}',
)
return result
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str) -> dict:
return {
'model_id': session.load_model(
IlastikPixelClassifierModel,
{'project_file': project_file}
)
}
return load_ilastik_model(IlastikPixelClassifierModel, project_file)
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str) -> dict:
return {
'model_id': session.load_model(
IlastikObjectClassifierModel,
{'project_file': project_file}
)
}
return load_ilastik_model(IlastikObjectClassifierModel, project_file)
@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
......@@ -109,6 +109,15 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertGreater(result.timer_results['inference'], 1.0)
class TestIlastikOverApi(TestServerBaseClass):
def test_httpexception_if_incorrect_project_file_loaded(self):
resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
params={'project_file': 'improper.ilp'},
)
self.assertEqual(resp_load.status_code, 404)
def test_load_ilastik_pixel_model(self):
resp_load = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment