from pathlib import Path
from typing import Dict, List, Union
import uuid

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from .accessors import generate_file_accessor, generate_multiposition_file_accessors
from .models import BinaryThresholdSegmentationModel
from .pipelines.shared import PipelineRecord
from .roiset import IntensityThresholdInstanceMaskSegmentationModel
from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError

# Application instance that every route in this module is registered on.
app = FastAPI(debug=True)

# The pipelines router is imported here, after `app` exists, instead of with the
# imports at the top of the file.
# NOTE(review): presumably this avoids a circular import with .pipelines -- confirm before moving it.
from .pipelines.router import router
app.include_router(router)


@app.on_event("startup")
def startup():
    """Startup hook registered on the app; it currently performs no work."""
    return None


@app.get('/favicon.ico')
async def favicon():
    """Handle requests for /favicon.ico; there is no icon, so the handler returns None."""
    no_icon = None
    return no_icon


@app.get('/')
def read_root():
    """Root endpoint; always answers with a success flag set to True."""
    return dict(success=True)


@app.get('/paths/session')
def get_top_session_path():
    """Return the 'session' entry of the paths reported by the session."""
    all_paths = session.get_paths()
    return all_paths['session']


@app.get('/paths/inbound')
def get_inbound_path():
    """Return the 'inbound_images' entry of the paths reported by the session."""
    all_paths = session.get_paths()
    return all_paths['inbound_images']


@app.get('/paths/outbound')
def get_outbound_path():
    """Return the 'outbound_images' entry of the paths reported by the session."""
    all_paths = session.get_paths()
    return all_paths['outbound_images']


@app.get('/paths')
def list_session_paths():
    """Return everything that session.get_paths() reports, unfiltered."""
    all_paths = session.get_paths()
    return all_paths


@app.get('/status')
def show_session_status():
    """
    Assemble a status report for the session.

    The report always carries 'status': 'running', followed by the session's memory
    figures, loaded-model descriptions, paths, accessors and tasks.
    """
    report = {'status': 'running'}
    # Entries are added in the same order the session is queried.
    report['memory'] = session.get_mem()
    report['models'] = session.describe_loaded_models()
    report['paths'] = session.get_paths()
    report['accessors'] = session.list_accessors()
    report['tasks'] = session.tasks.list_tasks()
    return report


def _change_path(key, path, touch=False) -> Union[str, None]:
    """
    Point the session data directory registered under `key` at `path`.

    When `touch` is true, two marker files are written as well: `<key>.path`, placed in
    the directory currently registered under `key` and containing the new path, and
    `svlt.touch`, placed in the new directory and containing a freshly generated UUID.

    :param key: name of the session path entry to change (e.g. 'inbound_images')
    :param path: new directory for that entry
    :param touch: whether to write the marker files described above
    :return: the UUID string written to `svlt.touch` when `touch` is true, otherwise None
    :raises HTTPException: 404 when the session raises InvalidPathError for `path`
    """
    try:
        if touch:
            # NOTE(review): this marker is written before `path` has been validated by
            # set_data_directory, so it is left behind if validation fails -- confirm
            # whether that is intended.
            with open(session.paths[key] / f'{key}.path', 'w') as fh:
                fh.write(str(path))
        session.set_data_directory(key, path)
        if not touch:
            return None
        uid = str(uuid.uuid4())
        with open(Path(path) / 'svlt.touch', 'w') as fh:
            fh.write(uid)
        return uid
    except InvalidPathError as e:
        # Keep the original error attached as the cause of the HTTP error.
        raise HTTPException(404, f'Did not find valid folder at: {path}') from e


@app.put('/paths/watch_input')
def watch_input_path(path: str, touch: bool = False) -> Union[str, None]:
    """Change the session's 'inbound_images' directory to `path`, forwarding `touch` to _change_path."""
    outcome = _change_path('inbound_images', path, touch=touch)
    return outcome


@app.put('/paths/watch_output')
def watch_output_path(path: str, touch: bool = False) -> Union[str, None]:
    """Change the session's 'outbound_images' directory to `path`, forwarding `touch` to _change_path."""
    outcome = _change_path('outbound_images', path, touch=touch)
    return outcome


@app.put('/paths/watch_conf')
def watch_conf_path(path: str, touch: bool = False) -> Union[str, None]:
    """
    Change the session's 'conf' directory to `path`, forwarding `touch` to _change_path.

    This handler used to share the name `watch_output_path` with the /paths/watch_output
    handler, which shadowed that function at module level; the route and its behavior are unchanged.
    """
    return _change_path('conf', path, touch=touch)


@app.get('/session/restart')
def restart_session() -> dict:
    """Restart the session, then return its description of loaded models."""
    session.restart()
    models_after_restart = session.describe_loaded_models()
    return models_after_restart


@app.get('/session/logs')
def list_session_log() -> list:
    """Return the log data held by the session, unfiltered."""
    log_entries = session.get_log_data()
    return log_entries


@app.get('/session/errors')
def get_errors() -> list:
    """Return only those session log entries whose 'level' is 'ERROR'."""
    log_entries = session.get_log_data()
    return [entry for entry in log_entries if entry['level'] == 'ERROR']


@app.get('/session/mem')
def get_mem() -> dict:
    """Return the memory information reported by session.get_mem()."""
    memory_info = session.get_mem()
    return memory_info


@app.get('/models')
def list_active_models():
    """Return the session's description of its loaded models."""
    loaded_models = session.describe_loaded_models()
    return loaded_models


class BinaryThresholdSegmentationParams(BaseModel):
    """Request body accepted by the threshold-model loading endpoints in this module."""
    # Threshold value; may be an int or a float and defaults to 0.5 when omitted.
    tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')


@app.put('/models/seg/threshold/load/')
def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
    """
    Load a BinaryThresholdSegmentationModel into the session with parameters `p`.

    `model_id` is passed to the session as the model key. The value returned by
    session.load_model is logged and sent back under 'model_id'.
    """
    loaded_id = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
    session.log_info(f'Loaded binary threshold segmentation model {loaded_id}')
    return dict(model_id=loaded_id)


@app.put('/models/classify/threshold/load')
def load_intensity_threshold_instance_segmentation_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
    """
    Load an IntensityThresholdInstanceMaskSegmentationModel into the session with parameters `p`.

    `model_id` is passed to the session as the model key. The value returned by
    session.load_model is logged and sent back under 'model_id'.
    """
    loaded_id = session.load_model(IntensityThresholdInstanceMaskSegmentationModel, key=model_id, params=p)
    session.log_info(f'Loaded permissive instance segmentation model {loaded_id}')
    return dict(model_id=loaded_id)


@app.get('/accessors')
def list_accessors() -> Dict:
    """Return the session's listing of accessors, unfiltered."""
    every_accessor = session.list_accessors()
    return every_accessor

@app.get('/accessors/loaded')
def list_loaded_accessors() -> Dict:
    """Return only the accessors whose listing entry has a truthy 'loaded' value."""
    every_accessor = session.list_accessors()
    return {acc_id: info for acc_id, info in every_accessor.items() if info['loaded']}

def _session_accessor(func, acc_id):
    """
    Call `func(acc_id)` and translate an unknown accessor ID into an HTTP 404.

    :param func: callable taking a single accessor ID
    :param acc_id: accessor ID handed to `func`
    :return: whatever `func` returns
    :raises HTTPException: 404 when `func` raises AccessorIdError
    """
    try:
        return func(acc_id)
    except AccessorIdError as e:
        # Keep the original error attached as the cause of the HTTP error.
        raise HTTPException(404, f'Did not find accessor with ID {acc_id}') from e


@app.get('/accessors/get/{accessor_id}')
def get_accessor(accessor_id: str):
    """Return the session's info for one accessor; _session_accessor turns an unknown ID into a 404."""
    lookup = session.get_accessor_info
    return _session_accessor(lookup, accessor_id)


@app.get('/accessors/delete/{accessor_id}')
def delete_accessor(accessor_id: str):
    """
    Delete one accessor, or every accessor when `accessor_id` is '*'.

    For a single ID the deletion goes through _session_accessor, so an unknown ID
    results in a 404.
    """
    if accessor_id != '*':
        return _session_accessor(session.del_accessor, accessor_id)
    # The wildcard clears all accessors held by the session.
    return session.del_all_accessors()

@app.put('/accessors/read_from_file/{filename}')
def read_accessor_from_file(filename: str, lazy: bool = False) -> str:
    """
    Create an accessor for `filename` in the session's inbound-images directory.

    Responds with a 404 when the file does not exist; otherwise the accessor is added to
    the session and the value returned by session.add_accessor is sent back. `lazy` is
    forwarded to generate_file_accessor.
    """
    inbound_dir = session.paths['inbound_images']
    source_file = inbound_dir / filename
    if source_file.exists():
        new_accessor = generate_file_accessor(source_file, lazy=lazy)
        return session.add_accessor(new_accessor)
    raise HTTPException(status_code=404, detail=f'Could not find file:\n{filename}')

@app.put('/accessors/write_to_file/{accessor_id}')
def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None, pop: bool = True) -> str:
    """
    Ask the session to write an accessor to file and return what session.write_accessor returns.

    `filename` and `pop` are forwarded to session.write_accessor unchanged. An unknown
    accessor ID yields a 404; a WriteAccessorError yields a 409 carrying the error's message.
    """
    try:
        return session.write_accessor(accessor_id, filename, pop=pop)
    except AccessorIdError as e:
        # Keep the original error attached as the cause of the HTTP error.
        raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') from e
    except WriteAccessorError as e:
        raise HTTPException(409, str(e)) from e

@app.put('/accessors/read_multiposition_file/{filename}')
def read_multiposition_from_file(filename: str, lazy: bool = False) -> List[str]:
    """
    Create accessors for a multiposition file in the session's inbound-images directory.

    Responds with a 404 when the file does not exist; otherwise every accessor obtained
    from the multiposition object is added to the session, and the values returned by
    session.add_accessor are sent back in the same order. `lazy` is forwarded to
    generate_multiposition_file_accessors.
    """
    inbound_dir = session.paths['inbound_images']
    source_file = inbound_dir / filename
    if not source_file.exists():
        raise HTTPException(status_code=404, detail=f'Could not find file:\n{filename}')
    position_accessors = generate_multiposition_file_accessors(source_file, lazy=lazy).get_accessors()
    added = []
    for one_position in position_accessors:
        added.append(session.add_accessor(one_position))
    return added

class TaskInfo(BaseModel):
    """Response schema that the /tasks endpoints in this module use to describe a single task."""
    # NOTE(review): the meaning of each field is not defined in this file; the values come
    # from session.tasks -- confirm the details there before relying on these notes.
    module: str
    params: dict
    func_str: str
    status: str  # presumably one of the values in session.tasks.status_codes -- verify
    target: Union[str, None]
    error: Union[str, None]
    result: Union[Dict, None]

@app.put('/tasks/run_all')
def run_all_tasks(write_all: bool = False) -> List[PipelineRecord]:
    """
    Run tasks one after another for as long as the session reports a next waiting task.

    `write_all` is forwarded to every session.tasks.run_task call. The results are
    returned in the order the tasks were run.
    """
    records = []
    while True:
        pending_id = session.tasks.next_waiting
        if not pending_id:
            # Nothing left waiting.
            break
        records.append(session.tasks.run_task(pending_id, write_all=write_all))
    return records

@app.get('/tasks/ready')
def list_ready_tasks() -> Dict[str, TaskInfo]:
    """
    Return the session's tasks whose status is 'READY'.

    This handler used to be named `list_waiting_tasks_by_status`, the same name as the
    /tasks/waiting handler defined after it, so that name already resolved to the
    later definition; the route and its behavior are unchanged.
    """
    return session.tasks.list_tasks(status='READY')

@app.get('/tasks/waiting')
def list_waiting_tasks_by_status() -> Dict[str, TaskInfo]:
    """Return the session's tasks whose status is 'WAITING'."""
    waiting_tasks = session.tasks.list_tasks(status='WAITING')
    return waiting_tasks

@app.get('/tasks/by_status')
def list_tasks_by_status() -> Dict[str, Dict[str, TaskInfo]]:
    """Return the session's tasks grouped by status, with one entry per value in session.tasks.status_codes."""
    grouped = {}
    for status_name in session.tasks.status_codes.values():
        grouped[status_name] = session.tasks.list_tasks(status=status_name)
    return grouped

@app.get('/tasks/next')
def get_next_waiting_text() -> Union[str, None]:
    """Return the current value of session.tasks.next_waiting."""
    upcoming = session.tasks.next_waiting
    return upcoming

@app.put('/tasks/delete_accessors/{task_id}')
def delete_accessors_by_parent_task(task_id: str) -> List[str]:
    """
    Delete accessors by parent task: `task_id` is passed to session.del_all_accessors as `parent_task_id`.

    Returns whatever session.del_all_accessors returns.
    """
    # The annotation uses typing.List, like the rest of this module, so it also
    # evaluates on interpreters older than Python 3.9.
    return session.del_all_accessors(parent_task_id=task_id)

@app.put('/tasks/delete_accessors')
def delete_accessors() -> List[str]:
    """Delete every accessor held by the session and return what session.del_all_accessors returns."""
    # The annotation uses typing.List, like the rest of this module, so it also
    # evaluates on interpreters older than Python 3.9.
    return session.del_all_accessors()

@app.put('/tasks/run/{task_id}')
def run_task(
        task_id: str,
        write_all: bool = False,
        prefix: Union[str, None] = None,
        output_subdirectory: Union[str, None] = None,
        allow_overwrite: bool = False
) -> PipelineRecord:
    """
    Run a single task and return the result of session.tasks.run_task.

    Every option (`write_all`, `prefix`, `output_subdirectory`, `allow_overwrite`) is
    forwarded to session.tasks.run_task unchanged. `prefix` and `output_subdirectory`
    may be omitted, in which case None is forwarded.
    """
    return session.tasks.run_task(
        task_id,
        write_all=write_all,
        prefix=prefix,
        output_subdirectory=output_subdirectory,
        allow_overwrite=allow_overwrite,
    )

@app.get('/tasks/get/{task_id}')
def get_task(task_id: str) -> TaskInfo:
    """Return the info that session.tasks holds for the task `task_id`."""
    task_info = session.tasks.get_task_info(task_id)
    return task_info

@app.get('/tasks/get_output_accessor/{task_id}')
def get_output_accessor_id_for_task(task_id: str) -> str:
    """Return the output accessor ID that session.tasks reports for the task `task_id`."""
    output_accessor_id = session.tasks.get_output_accessor_id(task_id)
    return output_accessor_id

@app.get('/tasks')
def list_tasks() -> Dict[str, TaskInfo]:
    """Return the session's task listing, with no status filter applied."""
    every_task = session.tasks.list_tasks()
    return every_task


@app.get('/phenobase/bounding_box')
def get_phenobase_bounding_boxes() -> list:
    """Return the bounding boxes listed by session.phenobase."""
    bounding_boxes = session.phenobase.list_bounding_boxes()
    return bounding_boxes