from fastapi import APIRouter, HTTPException

from model_server.session import Session
from model_server.validators import  validate_workflow_inputs

from extensions.ilastik import models as ilm
from model_server.models import ParameterExpectedError
from extensions.ilastik.workflows import infer_px_then_ob_model

# All routes registered on this router are prefixed with '/ilastik' and tagged 'ilastik'.
router = APIRouter(
    prefix='/ilastik',
    tags=['ilastik'],
)

# Module-level Session created once at import time; every endpoint in this module
# reads models and paths from, and loads models into, this same object.
session = Session()

def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplicate=True) -> dict:
    """
    Load an ilastik model of a given class from a project file and register it in the session.
    :param model_class: ilastik model class passed to session.load_model (callers in this module
        pass the class itself, e.g. ilm.IlastikPixelClassifierModel)
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: if True, always load another instance of the project file; if False, return
        the ID of an already-loaded model whose 'project_file' parameter matches, when one exists
    :return: dict of the form {'model_id': <model ID>}
    :raises HTTPException: 404 if loading raises FileNotFoundError or ParameterExpectedError
    """
    if not duplicate:
        # Reuse an existing model loaded from the same project file, if any.
        # is_path=True presumably makes the session compare values as paths -- confirm in Session.
        existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
        if existing_model_id is not None:
            return {'model_id': existing_model_id}
    try:
        model_id = session.load_model(model_class, {'project_file': project_file})
    except (FileNotFoundError, ParameterExpectedError) as e:
        # Chain the original error so the underlying cause survives in server-side tracebacks.
        raise HTTPException(
            status_code=404,
            detail=f'Could not load project file {project_file}',
        ) from e
    return {'model_id': model_id}

@router.put('/seg/load/')
def load_px_model(project_file: str, duplicate: bool = True) -> dict:
    """
    Load an ilastik pixel classifier project.
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: load a new instance even if this project file is already loaded
    :return: dict containing the model's ID
    """
    px_class = ilm.IlastikPixelClassifierModel
    return load_ilastik_model(px_class, project_file, duplicate=duplicate)

@router.put('/pxmap_to_obj/load/')
def load_pxmap_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
    """
    Load an ilastik object classifier project that takes pixel prediction maps as input.
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: load a new instance even if this project file is already loaded
    :return: dict containing the model's ID
    """
    ob_class = ilm.IlastikObjectClassifierFromPixelPredictionsModel
    return load_ilastik_model(ob_class, project_file, duplicate=duplicate)

@router.put('/seg_to_obj/load/')
def load_seg_to_obj_model(project_file: str, duplicate: bool = True) -> dict:
    """
    Load an ilastik object classifier project that takes segmentations as input.
    :param project_file: (*.ilp) ilastik project filename
    :param duplicate: load a new instance even if this project file is already loaded
    :return: dict containing the model's ID
    """
    ob_class = ilm.IlastikObjectClassifierFromSegmentationModel
    return load_ilastik_model(ob_class, project_file, duplicate=duplicate)

@router.put('/pixel_then_object_classification/infer')
def infer_px_then_ob_maps(px_model_id: str, ob_model_id: str, input_filename: str, channel: int = None) -> dict:
    """
    Run pixel classification followed by object classification on one inbound image.
    :param px_model_id: session ID of a loaded pixel classifier model
    :param ob_model_id: session ID of a loaded object classifier model
    :param input_filename: image filename, relative to the session's inbound images directory
    :param channel: optional channel selection, passed through to the workflow unchanged
    :return: the record returned by infer_px_then_ob_model
    :raises HTTPException: 409 if the workflow raises AssertionError (treated as incompatible models)
    """
    inpath = session.paths['inbound_images'] / input_filename
    validate_workflow_inputs([px_model_id, ob_model_id], [inpath])
    px_model = session.models[px_model_id]['object']
    ob_model = session.models[ob_model_id]['object']
    try:
        record = infer_px_then_ob_model(
            inpath,
            px_model,
            ob_model,
            session.paths['outbound_images'],
            channel=channel,
        )
    except AssertionError as e:
        # HTTPException's first positional argument is status_code, so the message must go in
        # `detail`; passing it positionally made the error response itself fail.
        # NOTE(review): assert-based signalling is stripped under `python -O`; an explicit
        # exception from the workflow would be more robust.
        raise HTTPException(
            status_code=409,
            detail=f'Incompatible models {px_model_id} and/or {ob_model_id}',
        ) from e
    return record