"""FastAPI application exposing model-server session state and inference endpoints."""
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, HTTPException

from model_server.models import DummyImageToImageModel
from model_server.session import Session, InvalidPathError
from model_server.validators import validate_workflow_inputs
from model_server.workflows import infer_image_to_image
from extensions.ilastik.workflows import infer_px_then_ob_model

app = FastAPI(debug=True)
session = Session()

# NOTE(review): this import sits below the creation of ``app`` and ``session``,
# presumably so the router module can import them from here without a
# circular-import failure -- confirm before hoisting it to the top of the file.
import extensions.ilastik.router  # noqa: E402

app.include_router(extensions.ilastik.router.router)


@app.on_event("startup")
def startup():
    """Application startup hook; currently a no-op placeholder."""
    pass


@app.get('/')
def read_root():
    """Report that the server is reachable."""
    return {'success': True}


@app.put('/bounce_back')
def list_bounce_back(par1=None, par2=None):
    """Echo the two optional query parameters back to the caller."""
    return {'success': True, 'params': {'par1': par1, 'par2': par2}}


@app.get('/paths')
def list_session_paths():
    """Return the session's current data-directory mapping."""
    return session.get_paths()


@app.get('/status')
def show_session_status():
    """Return a running flag, the loaded models, and the session's data directories."""
    return {
        'status': 'running',
        'models': session.describe_loaded_models(),
        'paths': session.get_paths(),
    }


def change_path(key, path):
    """Point the session's ``key`` data directory at ``path``.

    If the directory already equals ``path``, the session is left untouched.

    Args:
        key: Name of the session path entry to change (e.g. ``'inbound_images'``).
        path: New directory, as received from the request.

    Returns:
        The session's path mapping after the (possible) change.

    Raises:
        HTTPException: 404 when the session rejects ``path`` with
            ``InvalidPathError``.
    """
    current = session.get_paths().get(key)
    # The request supplies ``path`` as a str, while the session stores values
    # that support ``/`` joining (see ``infer_img``). A raw ``==`` between a
    # Path and a str is always False, so compare both as Path objects.
    if current is not None and Path(current) == Path(path):
        return session.get_paths()
    try:
        session.set_data_directory(key, path)
    except InvalidPathError as e:
        raise HTTPException(
            status_code=404,
            detail=str(e),
        ) from e
    return session.get_paths()


@app.put('/paths/watch_input')
def watch_input_path(path: str):
    """Set the session directory that ``infer_img`` reads input files from."""
    return change_path('inbound_images', path)


@app.put('/paths/watch_output')
def watch_output_path(path: str):
    """Set the session directory that ``infer_img`` passes to workflows as the output location."""
    return change_path('outbound_images', path)


@app.get('/restart')
def restart_session(root: Optional[str] = None) -> dict:
    """Restart the session, optionally with a new ``root``, and describe its loaded models."""
    session.restart(root=root)
    return session.describe_loaded_models()


@app.get('/models')
def list_active_models():
    """Describe the models currently loaded in the session."""
    return session.describe_loaded_models()


@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
    """Load a ``DummyImageToImageModel`` into the session and return its id."""
    return {'model_id': session.load_model(DummyImageToImageModel)}


@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: Optional[int] = None) -> dict:
    """Run image-to-image inference on a file from the session's input directory.

    Args:
        model_id: Id of a model previously loaded into the session.
        input_filename: File name, relative to the ``'inbound_images'`` directory.
        channel: Optional channel forwarded to ``infer_image_to_image``.

    Returns:
        The workflow record produced by ``infer_image_to_image``; it is also
        stored via ``session.record_workflow_run``.
    """
    # NOTE(review): ``input_filename`` is request-controlled and is joined
    # unchecked, so values containing ``..`` or an absolute path can point
    # outside the input directory -- confirm whether that is acceptable.
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([model_id], [inpath])
    record = infer_image_to_image(
        inpath,
        session.models[model_id]['object'],
        session.paths['outbound_images'],
        channel=channel,
    )
    session.record_workflow_run(record)
    return record