Skip to content
Snippets Groups Projects
api.py 2.31 KiB
from fastapi import FastAPI, HTTPException

from model_server.models import DummyImageToImageModel
from model_server.session import Session, InvalidPathError
from model_server.validators import validate_workflow_inputs
from model_server.workflows import infer_image_to_image
from extensions.ilastik.workflows import infer_px_then_ob_model

# Application instance; debug=True is passed straight to the FastAPI constructor.
app = FastAPI(debug=True)
# Module-level session object that the route handlers in this module operate on.
session = Session()

# NOTE(review): this import sits below the `app`/`session` definitions rather
# than with the other imports — presumably the router module needs them to
# exist first; confirm before moving it to the top of the file.
import extensions.ilastik.router
# Attach the routes exposed by the ilastik extension's router to the app.
app.include_router(extensions.ilastik.router.router)

@app.on_event("startup")
def startup():
    """Startup hook registered on the app; it currently performs no work."""
    return None

@app.get('/')
def read_root():
    """Answer a GET on the root path with a fixed success payload."""
    payload = {'success': True}
    return payload

@app.put('/bounce_back')
def list_bounce_back(par1=None, par2=None):
    """Echo the two optional parameters back to the caller.

    Returns:
        A dict with 'success' set to True and the received values stored
        under 'params' (keys 'par1' and 'par2').
    """
    echoed = dict(par1=par1, par2=par2)
    return dict(success=True, params=echoed)

@app.get('/paths')
def list_session_paths():
    """Return whatever the session reports from ``get_paths()``."""
    current_paths = session.get_paths()
    return current_paths

@app.get('/status')
def show_session_status():
    """Summarise the session: a fixed status string, loaded models and paths.

    Returns:
        A dict with keys 'status' (always 'running'), 'models' (from
        ``session.describe_loaded_models()``) and 'paths' (from
        ``session.get_paths()``).
    """
    loaded_models = session.describe_loaded_models()
    current_paths = session.get_paths()
    return dict(status='running', models=loaded_models, paths=current_paths)

def change_path(key, path):
    """Point one of the session's data directories at a new location.

    Args:
        key: Name of the entry in the session's paths to change.
        path: New directory for that entry.

    Returns:
        The session's paths, as returned by ``session.get_paths()``.

    Raises:
        HTTPException: With status 404 and the original error message as
            detail when the session raises ``InvalidPathError``.
    """
    try:
        # Only touch the session when the requested path differs from the
        # one it already holds for this key.
        if session.get_paths()[key] != path:
            session.set_data_directory(key, path)
    except InvalidPathError as e:
        # Chain the original error so its traceback is preserved in logs.
        raise HTTPException(
            status_code=404,
            detail=str(e),
        ) from e
    return session.get_paths()

@app.put('/paths/watch_input')
def watch_input_path(path: str):
    """Set the session's 'inbound_images' directory to ``path``.

    Delegates to ``change_path`` and returns its result.
    """
    updated_paths = change_path(key='inbound_images', path=path)
    return updated_paths

@app.put('/paths/watch_output')
def watch_output_path(path: str):
    """Set the session's 'outbound_images' directory to ``path``.

    Delegates to ``change_path`` and returns its result. The route path is
    unchanged; only the handler name differs from the input-path handler so
    the two no longer share one module-level name.
    """
    return change_path('outbound_images', path)

@app.get('/restart')
def restart_session(root: str = None) -> dict:
    """Restart the session, forwarding ``root``, then describe loaded models.

    Args:
        root: Optional value passed through as ``session.restart(root=...)``.

    Returns:
        The result of ``session.describe_loaded_models()`` after the restart.
    """
    session.restart(root=root)
    models_after_restart = session.describe_loaded_models()
    return models_after_restart

@app.get('/models')
def list_active_models():
    """Return the session's description of its currently loaded models."""
    loaded_models = session.describe_loaded_models()
    return loaded_models

@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
    """Load ``DummyImageToImageModel`` into the session.

    Returns:
        A dict whose 'model_id' entry is the value returned by
        ``session.load_model``.
    """
    new_model_id = session.load_model(DummyImageToImageModel)
    return {'model_id': new_model_id}

@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
    """Run image-to-image inference on a file in the inbound directory.

    Args:
        model_id: Key into ``session.models`` selecting the model to run.
        input_filename: File name joined onto the session's
            'inbound_images' path.
        channel: Optional value forwarded to ``infer_image_to_image``.

    Returns:
        The record produced by ``infer_image_to_image``; it is also passed
        to ``session.record_workflow_run`` before being returned.
    """
    # NOTE(review): input_filename is joined onto the inbound directory as
    # given — confirm validate_workflow_inputs rejects names that escape it.
    source_file = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([model_id], [source_file])

    model = session.models[model_id]['object']
    output_dir = session.paths['outbound_images']
    run_record = infer_image_to_image(
        source_file, model, output_dir, channel=channel
    )

    session.record_workflow_run(run_record)
    return run_record