Newer
Older
from fastapi import FastAPI, HTTPException

Christopher Randolph Rhodes
committed
from model_server.ilastik import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel

Christopher Randolph Rhodes
committed
from model_server.model import DummyImageToImageModel, ParameterExpectedError

Christopher Randolph Rhodes
committed
from model_server.session import Session
from model_server.workflow import infer_image_to_image

Christopher Randolph Rhodes
committed
from model_server.workflow_ilastik import infer_px_then_ob_model
# FastAPI application object; every endpoint below is registered on it via decorators.
# NOTE(review): debug=True makes the framework return tracebacks on server errors —
# consider disabling it outside development.
app = FastAPI(debug=True)
# Single module-level Session shared by all endpoints: they use it to load and look up
# models, resolve inbound/outbound paths, and record workflow runs.
session = Session()
@app.on_event("startup")
def startup():
    """Startup hook registered with FastAPI; currently performs no initialization."""
@app.get('/')
def read_root():
    """
    Root endpoint used as a simple liveness check.
    :return: dictionary {'success': True}
    """
    # NOTE(review): the body was missing in this source (a bodiless def is a SyntaxError).
    # This mirrors the 'success' flag returned by /bounce_back — confirm the intended response.
    return {'success': True}
@app.put('/bounce_back')
def list_bounce_back(par1=None, par2=None):
    """
    Echo the two optional parameters back to the caller.
    :param par1: any value; returned unchanged
    :param par2: any value; returned unchanged
    :return: dictionary with a success flag and the echoed parameters
    """
    echoed = {'par1': par1, 'par2': par2}
    return {'success': True, 'params': echoed}
@app.get('/paths')
def list_session_paths():
    """Return whatever the shared session reports from get_paths()."""
    paths = session.get_paths()
    return paths

Christopher Randolph Rhodes
committed
@app.get('/restart')
def restart_session(root: str = None) -> dict:
    """
    Restart the shared session, optionally under a new root.
    :param root: forwarded to Session.restart as root=
    :return: the session's description of loaded models after the restart
    """
    session.restart(root=root)
    loaded = session.describe_loaded_models()
    return loaded
@app.get('/models')
def list_active_models():
    """Return the shared session's description of currently loaded models."""
    loaded = session.describe_loaded_models()
    return loaded

Christopher Randolph Rhodes
committed
@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
    """
    Load a DummyImageToImageModel into the shared session.
    :return: dictionary with single key 'model_id'
    """
    new_id = session.load_model(DummyImageToImageModel)
    return {'model_id': new_id}

Christopher Randolph Rhodes
committed

Christopher Randolph Rhodes
committed
def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
    """
    Load an ilastik model of a given class and project filename.
    :param model_class: ilastik model class to instantiate (the class itself, not an instance)
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: load another instance of the same project file if True; return existing one if false
    :return: dictionary with single key 'model_id' describing model's ID, for both new and existing models
    :raises HTTPException: 404 if the project file cannot be loaded
    """
    if not duplicate:
        # NOTE(review): assumes find_param_in_loaded_models returns the matching model ID or None — confirm.
        existing_model_id = session.find_param_in_loaded_models('project_file', project_file)
        if existing_model_id is not None:
            # Wrap in the same shape as a fresh load so callers always receive {'model_id': ...}.
            return {'model_id': existing_model_id}
    try:
        model_id = session.load_model(
            model_class,
            {'project_file': project_file}
        )
    except (FileNotFoundError, ParameterExpectedError) as err:
        raise HTTPException(
            status_code=404,
            detail=f'Could not load project file {project_file}',
        ) from err
    return {'model_id': model_id}

Christopher Randolph Rhodes
committed
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
    """
    Load an ilastik pixel classifier from a project file.
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: forwarded to load_ilastik_model
    :return: the dictionary produced by load_ilastik_model
    """
    model_class = IlastikPixelClassifierModel
    return load_ilastik_model(model_class, project_file, duplicate=duplicate)

Christopher Randolph Rhodes
committed
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
    """
    Load an ilastik object classifier from a project file.
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: forwarded to load_ilastik_model
    :return: the dictionary produced by load_ilastik_model
    """
    model_class = IlastikObjectClassifierModel
    return load_ilastik_model(model_class, project_file, duplicate=duplicate)

Christopher Randolph Rhodes
committed
def validate_workflow_inputs(model_ids, inpaths):
    """
    Check that every requested model is loaded and every input file exists.
    :param model_ids: iterable of model IDs a workflow will use
    :param inpaths: iterable of input paths; each must provide .exists() (e.g. pathlib.Path)
    :raises HTTPException: 409 if a model has not been loaded; 404 if an input file is missing
    """
    # Fetch the loaded-model IDs once instead of rebuilding the description per model ID.
    loaded_model_ids = session.describe_loaded_models().keys()
    for mid in model_ids:
        if mid not in loaded_model_ids:
            raise HTTPException(
                status_code=409,
                detail=f'Model {mid} has not been loaded'
            )
    for inpa in inpaths:
        if not inpa.exists():
            raise HTTPException(
                status_code=404,
                detail=f'Could not find file:\n{inpa}'
            )
@app.put('/infer/from_image_file')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
    """
    Run a loaded image-to-image model on a file from the session's inbound images directory.
    :param model_id: ID of a model previously loaded into the session
    :param input_filename: filename relative to session.paths['inbound_images']
    :param channel: optional channel selector, forwarded to infer_image_to_image
    :return: the record returned by infer_image_to_image (also recorded on the session)
    :raises HTTPException: 409 if the model is not loaded; 404 if the input file does not exist
    """
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([model_id], [inpath])
    # Pass the validated input file first, matching the argument order used for infer_px_then_ob_model.
    # NOTE(review): confirm against infer_image_to_image's signature.
    record = infer_image_to_image(
        inpath,
        session.models[model_id]['object'],
        session.paths['outbound_images'],
        channel=channel,
    )
    session.record_workflow_run(record)
    return record
@app.put('/models/ilastik/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
    """
    Run a pixel classifier followed by an object classifier on one inbound image.
    :param px_model_id: ID of a loaded ilastik pixel classification model
    :param ob_model_id: ID of a loaded ilastik object classification model
    :param input_filename: filename relative to session.paths['inbound_images']
    :param channel: optional channel selector, forwarded to infer_px_then_ob_model
    :return: the record returned by infer_px_then_ob_model
    :raises HTTPException: 409 if a model is not loaded or the two models are incompatible;
        404 if the input file does not exist
    """
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
    try:
        record = infer_px_then_ob_model(
            inpath,
            session.models[px_model_id]['object'],
            session.models[ob_model_id]['object'],
            session.paths['outbound_images'],
            channel=channel
        )
    except AssertionError as err:
        # Presumably raised by the workflow's model-compatibility checks.
        # HTTPException's first positional parameter is status_code, so the message must go in detail=.
        raise HTTPException(
            status_code=409,
            detail=f'Incompatible models {px_model_id} and/or {ob_model_id}',
        ) from err
    return record