from typing import Optional, Type

from fastapi import FastAPI, HTTPException

from model_server.ilastik import (
    IlastikImageToImageModel,
    IlastikObjectClassifierModel,
    IlastikPixelClassifierModel,
)
from model_server.model import DummyImageToImageModel, ParameterExpectedError
from model_server.session import Session
from model_server.workflow import infer_image_to_image
from model_server.workflow_ilastik import infer_px_then_ob_model

# Module-level application and the single Session shared by every endpoint below.
app = FastAPI(debug=True)
session = Session()


@app.on_event("startup")
def startup():
    """Startup hook; currently a no-op placeholder."""
    pass


@app.get('/')
def read_root():
    """Liveness check: always reports success."""
    return {'success': True}


@app.put('/bounce_back')
def list_bounce_back(par1=None, par2=None):
    """Echo the two optional query parameters back to the caller."""
    return {'success': True, 'params': {'par1': par1, 'par2': par2}}


@app.get('/paths')
def list_session_paths():
    """Return the session's configured paths."""
    return session.get_paths()


@app.get('/restart')
def restart_session(root: Optional[str] = None) -> dict:
    """
    Restart the session, optionally under a new root, and report loaded models.

    :param root: passed through to Session.restart; None when not supplied
    :return: the session's description of its loaded models after the restart
    """
    session.restart(root=root)
    return session.describe_loaded_models()


@app.get('/models')
def list_active_models():
    """Return the session's description of all currently loaded models."""
    return session.describe_loaded_models()


@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
    """Load a DummyImageToImageModel into the session and return its ID."""
    return {'model_id': session.load_model(DummyImageToImageModel)}


def load_ilastik_model(
        model_class: Type[IlastikImageToImageModel],
        project_file: str,
        duplicate: bool = True,
) -> dict:
    """
    Load an ilastik model of a given class and project filename.

    :param model_class: model class (not an instance) to load, e.g.
        IlastikPixelClassifierModel or IlastikObjectClassifierModel
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: load another instance of the same project file if True;
        return the existing one if False
    :return: dictionary with single key 'model_id' describing the model's ID
    :raises HTTPException: 404 if the project file cannot be loaded
    """
    if not duplicate:
        existing_model = session.find_param_in_loaded_models('project_file', project_file)
        # NOTE(review): this passes through whatever find_param_in_loaded_models
        # returns, which may not have the {'model_id': ...} shape of the normal
        # return below -- confirm against Session.
        if existing_model is not None:
            return existing_model
    try:
        model_id = session.load_model(model_class, {'project_file': project_file})
    except (FileNotFoundError, ParameterExpectedError) as err:
        raise HTTPException(
            status_code=404,
            detail=f'Could not load project file {project_file}',
        ) from err
    return {'model_id': model_id}


@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
    """Load an ilastik pixel-classification project; see load_ilastik_model."""
    return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)


@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
    """Load an ilastik object-classification project; see load_ilastik_model."""
    return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate)


def validate_workflow_inputs(model_ids, inpaths):
    """
    Check that every model ID is loaded and every input path exists.

    :param model_ids: iterable of model IDs to look up among loaded models
    :param inpaths: iterable of path objects; each must provide .exists()
    :raises HTTPException: 409 if a model is not loaded; 404 if a file is missing
    """
    # Query the session once rather than once per model ID.
    loaded_models = session.describe_loaded_models()
    for mid in model_ids:
        if mid not in loaded_models:
            raise HTTPException(
                status_code=409,
                detail=f'Model {mid} has not been loaded'
            )
    for inpa in inpaths:
        if not inpa.exists():
            raise HTTPException(
                status_code=404,
                detail=f'Could not find file:\n{inpa}'
            )


@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: Optional[int] = None) -> dict:
    """
    Run a loaded image-to-image model on a file from the inbound images directory.

    :param model_id: ID of a model already loaded into the session
    :param input_filename: filename relative to session.paths['inbound_images']
    :param channel: optional channel passed through to infer_image_to_image
    :return: the workflow record, which is also stored via record_workflow_run
    :raises HTTPException: 409 if the model is not loaded; 404 if the file is missing
    """
    # NOTE(review): input_filename is caller-supplied and joined unchecked, so
    # '..' components could reach outside inbound_images -- consider validating.
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([model_id], [inpath])
    record = infer_image_to_image(
        inpath,
        session.models[model_id]['object'],
        session.paths['outbound_images'],
        channel=channel,
    )
    session.record_workflow_run(record)
    return record


@app.put('/models/ilastik/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(
        px_model_id: str,
        ob_model_id: str,
        input_filename: str,
        channel: Optional[int] = None,
) -> dict:
    """
    Run a pixel classifier and then an object classifier on an inbound image file.

    :param px_model_id: ID of a loaded pixel-classification model
    :param ob_model_id: ID of a loaded object-classification model
    :param input_filename: filename relative to session.paths['inbound_images']
    :param channel: optional channel passed through to infer_px_then_ob_model
    :return: the workflow record produced by infer_px_then_ob_model
    :raises HTTPException: 409 if a model is not loaded or the models are
        incompatible; 404 if the file is missing
    """
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
    try:
        record = infer_px_then_ob_model(
            inpath,
            session.models[px_model_id]['object'],
            session.models[ob_model_id]['object'],
            session.paths['outbound_images'],
            channel=channel,
        )
    # NOTE(review): relies on AssertionError from downstream code; assertions are
    # stripped under `python -O`, so this check may silently disappear.
    except AssertionError as err:
        # HTTPException's first positional parameter is status_code, so the
        # status must be given explicitly and the message passed as detail.
        raise HTTPException(
            status_code=409,
            detail=f'Incompatible models {px_model_id} and/or {ob_model_id}',
        ) from err
    # NOTE(review): unlike infer_img, this run is not passed to
    # session.record_workflow_run -- confirm whether that is intentional.
    return record