Skip to content
Snippets Groups Projects
api.py 2.88 KiB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from base.models import DummyInstanceSegmentationModel, DummySemanticSegmentationModel
from base.session import session, InvalidPathError
from base.validators import validate_workflow_inputs
from base.workflows import classify_pixels
from extensions.ilastik.workflows import infer_px_then_ob_model

# Application object that every route decorator in this module attaches to.
app = FastAPI(debug=True)

# The ilastik extension contributes its own routes through a router object.
# NOTE(review): this import sits below the creation of `app` rather than with the
# other imports, and uses the `model_server.` package prefix while the imports at
# the top of the file do not -- presumably deliberate (import order / package
# layout); confirm before moving or rewriting it.
import model_server.extensions.ilastik.router
app.include_router(model_server.extensions.ilastik.router.router)

@app.on_event("startup")
def startup():
    """Startup hook registered with the application; currently does nothing."""

@app.get('/')
def read_root():
    """Root endpoint: always answers with a success flag."""
    response = {'success': True}
    return response

class BounceBackParams(BaseModel):
    """Request body accepted by the PUT /bounce_back endpoint."""

    # Both fields are echoed back unchanged by the endpoint.
    par1: str
    # Plain `list`: the element type is not constrained by this annotation.
    par2: list

@app.put('/bounce_back')
def list_bounce_back(params: BounceBackParams):
    """Echo the submitted parameters back to the caller.

    :param params: request body carrying par1 and par2
    :return: dict with 'success' set to True and the received values under 'params'
    """
    echoed = {'par1': params.par1, 'par2': params.par2}
    return {'success': True, 'params': echoed}

@app.get('/paths')
def list_session_paths():
    """Return whatever the session reports from get_paths()."""
    current_paths = session.get_paths()
    return current_paths

@app.get('/status')
def show_session_status():
    """Summarise the server state: a fixed 'running' status plus the session's models and paths."""
    # Query the session in the same order as the keys appear in the response.
    loaded_models = session.describe_loaded_models()
    current_paths = session.get_paths()
    summary = {
        'status': 'running',
        'models': loaded_models,
        'paths': current_paths,
    }
    return summary

def change_path(key, path):
    """Point one of the session's data directories at a new location.

    :param key: name of the session path entry to change; the endpoints in this
        module pass 'inbound_images' or 'outbound_images'
    :param path: new directory for that entry
    :return: the session's paths, as reported by session.get_paths()
    :raises HTTPException: with status 404 when the session raises InvalidPathError
    """
    try:
        # Requested value is already in place: return early without touching
        # the session or writing a log entry.
        if session.get_paths()[key] == path:
            return session.get_paths()
        session.set_data_directory(key, path)
    except InvalidPathError as e:
        # Translate the session error into an HTTP response, keeping the
        # original exception attached as the cause for tracebacks.
        raise HTTPException(
            status_code=404,
            detail=str(e),
        ) from e
    session.log_info(f'Change {key} path to {path}')
    return session.get_paths()

@app.put('/paths/watch_input')
def watch_input_path(path: str):
    """Set the session's 'inbound_images' path and return the resulting session paths.

    :param path: new location for inbound images
    """
    updated_paths = change_path(key='inbound_images', path=path)
    return updated_paths

@app.put('/paths/watch_output')
def watch_output_path(path: str):
    """Set the session's 'outbound_images' path and return the resulting session paths.

    :param path: new location for outbound images
    """
    updated_paths = change_path(key='outbound_images', path=path)
    return updated_paths

@app.get('/session/restart')
def restart_session(root: str = None) -> dict:
    """Restart the session, then report the models it describes as loaded.

    :param root: forwarded to session.restart as its `root` argument; None when omitted
    :return: result of session.describe_loaded_models() after the restart
    """
    session.restart(root=root)
    models_after_restart = session.describe_loaded_models()
    return models_after_restart

@app.get('/session/logs')
def list_session_log() -> list:
    """Return the session's log data as provided by session.get_log_data()."""
    log_entries = session.get_log_data()
    return log_entries

@app.get('/models')
def list_active_models():
    """Return the session's description of its loaded models."""
    loaded_models = session.describe_loaded_models()
    return loaded_models

@app.put('/models/dummy_semantic/load/')
def load_dummy_semantic_model() -> dict:
    """Load a DummySemanticSegmentationModel into the session.

    Named distinctly from the dummy instance-model loader so that the two
    handlers do not share (and overwrite) one module-level name.

    :return: dict with the ID returned by session.load_model under 'model_id'
    """
    mid = session.load_model(DummySemanticSegmentationModel)
    session.log_info(f'Loaded model {mid}')
    return {'model_id': mid}

@app.put('/models/dummy_instance/load/')
def load_dummy_model() -> dict:
    """Load a DummyInstanceSegmentationModel into the session.

    :return: dict with the ID returned by session.load_model under 'model_id'
    """
    new_model_id = session.load_model(DummyInstanceSegmentationModel)
    session.log_info(f'Loaded model {new_model_id}')
    response = {'model_id': new_model_id}
    return response

@app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
    """Run pixel classification on one inbound image with an already-loaded model.

    :param model_id: key into session.models selecting the model to apply
    :param input_filename: file name, joined onto the session's 'inbound_images' path
    :param channel: forwarded to classify_pixels as `channel`; None when omitted
    :return: the record returned by classify_pixels
    """
    # NOTE(review): input_filename comes straight from the request and is joined
    # onto the inbound directory as-is -- confirm that validate_workflow_inputs
    # rejects names that would resolve outside that directory.
    image_path = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([model_id], [image_path])

    # Look up the model and output location only after validation has passed.
    model = session.models[model_id]['object']
    output_dir = session.paths['outbound_images']
    result = classify_pixels(image_path, model, output_dir, channel=channel)

    session.log_info(f'Completed segmentation of {input_filename}')
    return result