diff --git a/api.py b/api.py index 45ee844fcd17d6f33a657b615aba446c02a6608f..fdec6903976aa79ace5630cb9ffcbdf9848ae341 100644 --- a/api.py +++ b/api.py @@ -1,14 +1,17 @@ -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI -from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel -from model_server.models import DummyImageToImageModel, ParameterExpectedError +from model_server.models import DummyImageToImageModel from model_server.session import Session +from model_server.validators import validate_workflow_inputs from model_server.workflows import infer_image_to_image from extensions.ilastik.workflows import infer_px_then_ob_model app = FastAPI(debug=True) session = Session() +import extensions.ilastik.router +app.include_router(extensions.ilastik.router.router) + @app.on_event("startup") def startup(): pass @@ -38,54 +41,6 @@ def list_active_models(): def load_dummy_model() -> dict: return {'model_id': session.load_model(DummyImageToImageModel)} -def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: - """ - Load an ilastik model of a given class and project filename. - :param model_class: - :param project_file: (*.ilp) ilastik project filename - :param duplicate: load another instance of the same project file if True; return existing one if false - :return: dictionary with single key describing model's ID - """ - if not duplicate: - existing_model = session.find_param_in_loaded_models('project_file', project_file) - if existing_model is not None: - return existing_model - try: - result = { - 'model_id': session.load_model( - model_class, - {'project_file': project_file} - ) - } - except (FileNotFoundError, ParameterExpectedError): - raise HTTPException( - status_code=404, - detail=f'Could not load project file {project_file}', - ) - return result - -@app.put('/models/ilastik/pixel_classification/load/') -def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict: - return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate) - -@app.put('/models/ilastik/object_classification/load/') -def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: - return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate) - -def validate_workflow_inputs(model_ids, inpaths): - for mid in model_ids: - if mid not in session.describe_loaded_models().keys(): - raise HTTPException( - status_code=409, - detail=f'Model {mid} has not been loaded' - ) - for inpa in inpaths: - if not inpa.exists(): - raise HTTPException( - status_code=404, - detail=f'Could not find file:\n{inpa}' - ) - @app.put('/infer/from_image_file') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: inpath = session.paths['inbound_images'] / input_filename @@ -97,20 +52,4 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: channel=channel, ) session.record_workflow_run(record) - return record - -@app.put('/models/ilastik/pixel_then_object_classification/infer') -def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: - inpath = session.paths['inbound_images'] / input_filename - validate_workflow_inputs([px_model_id, ob_model_id], [inpath]) - try: - record = infer_px_then_ob_model( - inpath, - session.models[px_model_id]['object'], - session.models[ob_model_id]['object'], - session.paths['outbound_images'], - channel=channel - ) - except AssertionError: - raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') return record \ No newline at end of file diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py new file mode 100644 index 0000000000000000000000000000000000000000..155d8fa057e22c6709fa683b39c9d6c396ebc90b --- /dev/null +++ b/extensions/ilastik/router.py @@ -0,0 +1,65 @@ +from fastapi import APIRouter, HTTPException + +from model_server.session import Session +from model_server.validators import validate_workflow_inputs + +from extensions.ilastik.models import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel +from model_server.models import ParameterExpectedError +from extensions.ilastik.workflows import infer_px_then_ob_model + +router = APIRouter( + prefix='/ilastik', + tags=['ilastik'], +) + +session = Session() + +def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict: + """ + Load an ilastik model of a given class and project filename. + :param model_class: + :param project_file: (*.ilp) ilastik project filename + :param duplicate: load another instance of the same project file if True; return existing one if false + :return: dictionary with single key describing model's ID + """ + if not duplicate: + existing_model = session.find_param_in_loaded_models('project_file', project_file) + if existing_model is not None: + return existing_model + try: + result = { + 'model_id': session.load_model( + model_class, + {'project_file': project_file} + ) + } + except (FileNotFoundError, ParameterExpectedError): + raise HTTPException( + status_code=404, + detail=f'Could not load project file {project_file}', + ) + return result + +@router.put('/pixel_classification/load/') +def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate) + +@router.put('/object_classification/load/') +def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict: + return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate) + +@router.put('/pixel_then_object_classification/infer') +def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict: + inpath = session.paths['inbound_images'] / input_filename + validate_workflow_inputs([px_model_id, ob_model_id], [inpath]) + try: + record = infer_px_then_ob_model( + inpath, + session.models[px_model_id]['object'], + session.models[ob_model_id]['object'], + session.paths['outbound_images'], + channel=channel + ) + except AssertionError: + raise HTTPException(f'Incompatible models {px_model_id} and/or {ob_model_id}') + return record \ No newline at end of file diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index f8ec7e9e683e4a96c605a0c748e8d6ed1b5fd04a..555d8286c46121e409dab94fa31d2c75ac5b9a37 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -112,7 +112,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_httpexception_if_incorrect_project_file_loaded(self): resp_load = requests.put( - self.uri + 'models/ilastik/pixel_classification/load/', + self.uri + 'ilastik/pixel_classification/load/', params={'project_file': 'improper.ilp'}, ) self.assertEqual(resp_load.status_code, 404) @@ -120,7 +120,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_pixel_model(self): resp_load = requests.put( - self.uri + 'models/ilastik/pixel_classification/load/', + self.uri + 'ilastik/pixel_classification/load/', params={'project_file': str(conf.testing.ilastik['pixel_classifier'])}, ) model_id = resp_load.json()['model_id'] @@ -137,7 +137,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_1st = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_1st), 1, resp_list_1st) resp_load_2nd = requests.put( - self.uri + 'models/ilastik/pixel_classification/load/', + self.uri + 'ilastik/pixel_classification/load/', params={ 'project_file': str(conf.testing.ilastik['pixel_classifier']), 'duplicate': True, @@ -146,7 +146,7 @@ class TestIlastikOverApi(TestServerBaseClass): resp_list_2nd = requests.get(self.uri + 'models').json() self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd) resp_load_3rd = requests.put( - self.uri + 'models/ilastik/pixel_classification/load/', + self.uri + 'ilastik/pixel_classification/load/', params={ 'project_file': str(conf.testing.ilastik['pixel_classifier']), 'duplicate': False, @@ -158,7 +158,7 @@ class TestIlastikOverApi(TestServerBaseClass): def test_load_ilastik_object_model(self): resp_load = requests.put( - self.uri + 'models/ilastik/object_classification/load/', + self.uri + 'ilastik/object_classification/load/', params={'project_file': str(conf.testing.ilastik['object_classifier'])}, ) model_id = resp_load.json()['model_id'] @@ -190,7 +190,7 @@ class TestIlastikOverApi(TestServerBaseClass): ob_model_id = self.test_load_ilastik_object_model() resp_infer = requests.put( - self.uri + f'models/ilastik/pixel_then_object_classification/infer/', + self.uri + f'ilastik/pixel_then_object_classification/infer/', params={ 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, diff --git a/model_server/validators.py b/model_server/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..56e251bc0bceee9c84d4254e746459d0bcf6c775 --- /dev/null +++ b/model_server/validators.py @@ -0,0 +1,19 @@ +from fastapi import HTTPException + +from model_server.session import Session + +session = Session() + +def validate_workflow_inputs(model_ids, inpaths): + for mid in model_ids: + if mid not in session.describe_loaded_models().keys(): + raise HTTPException( + status_code=409, + detail=f'Model {mid} has not been loaded' + ) + for inpa in inpaths: + if not inpa.exists(): + raise HTTPException( + status_code=404, + detail=f'Could not find file:\n{inpa}' + ) \ No newline at end of file