from typing import Optional, Type

from fastapi import APIRouter, HTTPException

from model_server.session import Session
from model_server.validators import validate_workflow_inputs
from extensions.ilastik import models as ilm
from model_server.models import ParameterExpectedError
from extensions.ilastik.workflows import infer_px_then_ob_model

router = APIRouter(
    prefix='/ilastik',
    tags=['ilastik'],
)

session = Session()


def load_ilastik_model(
        model_class: Type[ilm.IlastikModel],
        project_file: str,
        duplicate: bool = True,
) -> dict:
    """
    Load an ilastik model of a given class and project filename.

    :param model_class: ilastik model class to instantiate via session.load_model
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: load another instance of the same project file if True; return existing one if False
    :return: dict containing model's ID under key 'model_id'
    :raises HTTPException: 404 if session.load_model raises FileNotFoundError or ParameterExpectedError
    """
    if not duplicate:
        # Reuse an already-loaded model whose 'project_file' parameter matches, if any
        existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
        if existing_model_id is not None:
            return {'model_id': existing_model_id}
    try:
        result = session.load_model(model_class, {'project_file': project_file})
    except (FileNotFoundError, ParameterExpectedError) as e:
        raise HTTPException(
            status_code=404,
            detail=f'Could not load project file {project_file}',
        ) from e
    return {'model_id': result}


@router.put('/seg/load/')
def load_px_model(project_file: str, duplicate: bool = True) -> dict:
    """Load an ilastik pixel classifier project and return its model ID."""
    return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)


@router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
    """Load an ilastik object classifier (from pixel predictions) project and return its model ID."""
    return load_ilastik_model(ilm.IlastikObjectClassifierFromPixelPredictionsModel, project_file, duplicate=duplicate)


@router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
    """Load an ilastik object classifier (from segmentation) project and return its model ID."""
    return load_ilastik_model(ilm.IlastikObjectClassifierFromSegmentationModel, project_file, duplicate=duplicate)


@router.put('/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(
        px_model_id: str,
        ob_model_id: str,
        input_filename: str,
        channel: Optional[int] = None,
) -> dict:
    """
    Run the pixel-then-object classification workflow on an inbound image.

    :param px_model_id: ID of a loaded model passed to the workflow as the pixel classifier
    :param ob_model_id: ID of a loaded model passed to the workflow as the object classifier
    :param input_filename: image filename, joined onto the session's 'inbound_images' path
    :param channel: optional channel index, forwarded unchanged to the workflow
    :return: the record returned by infer_px_then_ob_model
    :raises HTTPException: 409 if the workflow raises AssertionError (presumably incompatible models)
    """
    # NOTE(review): input_filename comes straight from the query string; '..' segments or an
    # absolute path would escape the inbound directory. Confirm validate_workflow_inputs guards this.
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
    try:
        record = infer_px_then_ob_model(
            inpath,
            session.models[px_model_id]['object'],
            session.models[ob_model_id]['object'],
            session.paths['outbound_images'],
            channel=channel,
        )
    except AssertionError as e:
        # HTTPException's first positional argument is the status code, so the message must go
        # in `detail`; passing it positionally made the handler error out instead of reporting.
        raise HTTPException(
            status_code=409,
            detail=f'Incompatible models {px_model_id} and/or {ob_model_id}',
        ) from e
    return record