from collections import OrderedDict
from pathlib import Path
from time import perf_counter
from typing import List, Union

from fastapi import HTTPException
from numba.scripts.generate_lower_listing import description
from pydantic import BaseModel, Field, root_validator

from ..accessors import GenericImageDataAccessor, InMemoryDataAccessor
from ..roiset import RoiSet
from ..session import session, AccessorIdError


class PipelineParams(BaseModel):
    """
    Base parameter model for pipeline endpoints.  When `api` is True, the
    validators cross-check every `*_model_id` / `*_accessor_id` field against
    the server session and surface failures as HTTP 409 errors.
    """
    schedule: bool = Field(False, description='Schedule as a task instead of running immediately')
    keep_interm: bool = Field(False, description='Keep accessors to intermediate images in session')
    api: bool = Field(True, description='Validate parameters against server session and map HTTP errors if True')

    @root_validator(pre=False)
    def models_are_loaded(cls, dd):
        # Every non-None *_model_id must refer to a model the session has loaded
        for field, model_id in dd.items():
            if not dd['api'] or not field.endswith('model_id') or model_id is None:
                continue
            if model_id not in session.describe_loaded_models().keys():
                raise HTTPException(status_code=409, detail=f'Model with {field} = {model_id} has not been loaded')
        return dd

    @root_validator(pre=False)
    def accessors_are_loaded(cls, dd):
        # Every *_accessor_id must resolve to an accessor that is currently loaded
        for field, acc_id in dd.items():
            if not dd['api'] or not field.endswith('accessor_id'):
                continue
            try:
                info = session.get_accessor_info(acc_id)
            except AccessorIdError as e:
                raise HTTPException(status_code=409, detail=str(e))
            if not info['loaded']:
                raise HTTPException(status_code=409, detail=f'Accessor with {field} = {acc_id} has not been loaded')
        return dd


class PipelineRecord(BaseModel):
    """Result of an immediately-executed pipeline run."""
    # session ID of the accessor holding the pipeline's final output
    output_accessor_id: str
    # session IDs of kept intermediate accessors, or None when keep_interm was False
    interm_accessor_ids: Union[List[str], None]
    # True when the pipeline completed
    success: bool
    # mapping of pipeline step name to its incremental execution time
    timer: dict
    # session ID of the RoiSet produced by the pipeline, if it returned one
    roiset_id: Union[str, None] = None


class PipelineQueueRecord(BaseModel):
    """Handle to a pipeline run that was scheduled as a queued task."""
    # ID returned by session.queue.add_task for the scheduled run
    task_id: str

def call_pipeline(func, p: PipelineParams) -> Union[PipelineRecord, PipelineQueueRecord]:
    """
    Run a pipeline function, resolving its accessor and model IDs via the session.

    :param func: pipeline callable invoked as func(accessors, models, **params);
        expected to return a PipelineTrace or a (PipelineTrace, RoiSet) tuple
    :param p: validated pipeline parameters
    :return: PipelineQueueRecord when the run is scheduled as a task, otherwise
        a PipelineRecord describing the run's outputs
    :raises NoAccessorsFoundError: if no *_accessor_id parameter is present
    :raises UnexpectedPipelineReturnError: if func returns an unsupported value
    """
    # instead of running right away, schedule pipeline as a task
    if p.schedule:
        p.schedule = False  # prevent the queued task from re-scheduling itself
        task_id = session.queue.add_task(
            lambda x: call_pipeline(func, x),
            p
        )
        return PipelineQueueRecord(task_id=task_id)

    # snapshot parameters once instead of re-serializing the model repeatedly
    params = p.dict()

    # match accessor IDs to loaded accessor objects
    accessors_in = {
        k.split('accessor_id')[0]: session.get_accessor(v, pop=True)
        for k, v in params.items()
        if k.endswith('accessor_id')
    }
    if not accessors_in:
        raise NoAccessorsFoundError('Expecting at least one valid accessor to run pipeline')

    # use first validated accessor ID to name log file entries and derived accessors
    input_description = [v for k, v in params.items() if k.endswith('accessor_id')][0]

    # match model IDs to loaded model objects
    models = {
        k.split('model_id')[0]: session.models[v]['object']
        for k, v in params.items()
        if k.endswith('model_id') and v is not None
    }

    # call the actual pipeline; expect a single PipelineTrace or a tuple where first element is PipelineTrace
    ret = func(
        accessors_in,
        models,
        **params,
    )
    if isinstance(ret, PipelineTrace):
        trace = ret
        roiset_id = None
    elif isinstance(ret, tuple) and isinstance(ret[0], PipelineTrace) and isinstance(ret[1], RoiSet):
        trace = ret[0]
        roiset_id = session.add_roiset(ret[1])
    else:
        raise UnexpectedPipelineReturnError(
            f'{func.__name__} returned unexpected value of {type(ret)}'
        )
    session.log_info(f'Completed {func.__name__} on {input_description}.')

    # map intermediate data accessors to accessor IDs
    if p.keep_interm:
        interm_ids = [
            session.add_accessor(
                acc,
                accessor_id=f'{input_description}_{func.__name__}_step{(i + 1):02d}_{stk}'
            )
            for i, (stk, acc) in enumerate(trace.accessors(skip_first=True, skip_last=True).items())
        ]
    else:
        interm_ids = None

    # map final result to an accessor ID; named from input_description (the first
    # accessor ID) — PipelineParams has no field literally called 'accessor_id',
    # so the previous `p.accessor_id` raised AttributeError here
    result_id = session.add_accessor(
        trace.last,
        accessor_id=f'{input_description}_{func.__name__}_result'
    )

    record = PipelineRecord(
        output_accessor_id=result_id,
        interm_accessor_ids=interm_ids,
        success=True,
        timer=trace.times,
        roiset_id=roiset_id,
    )

    return record


class PipelineTrace(OrderedDict):
    """
    An ordered container that records the data produced at each stage of a
    pipeline along with the incremental wall-clock time each stage took.
    """
    # timer used to measure per-stage elapsed times
    tfunc = perf_counter

    def __init__(self, start_acc: GenericImageDataAccessor = None, enforce_accessors=True, allow_overwrite=False):
        """
        A container and timer for data at each stage of a pipeline.
        :param start_acc: (optional) accessor to initialize as 'input' step
        :param enforce_accessors: if True, only allow accessors to be appended as items
        :param allow_overwrite: if True, allow an item to be overwritten
        """
        self.enforce_accessors = enforce_accessors
        self.allow_overwrite = allow_overwrite
        self.last_time = self.tfunc()
        self.markers = None
        self.timer = OrderedDict()
        super().__init__()
        if start_acc is not None:
            self['input'] = start_acc

    def copy(self):
        """
        Return a shallow copy that preserves items, timer data, and configuration.

        Overridden because OrderedDict.copy() calls type(self)(self), which would
        pass this trace as start_acc — dropping all items/timers and, with
        enforce_accessors=True, raising NoAccessorsFoundError.
        """
        dup = self.__class__(
            enforce_accessors=self.enforce_accessors,
            allow_overwrite=self.allow_overwrite,
        )
        for k, v in self.items():
            # bypass __setitem__ so copying does not re-time or re-validate items
            OrderedDict.__setitem__(dup, k, v)
        dup.timer = self.timer.copy()
        dup.last_time = self.last_time
        dup.markers = self.markers
        return dup

    def __setitem__(self, key, value: GenericImageDataAccessor):
        """
        Record a pipeline step, timing it relative to the previous insertion.
        :param key: step name; must be new unless allow_overwrite is True
        :param value: accessor, or raw data to wrap if enforce_accessors is False
        :raises NoAccessorsFoundError: if value is not an accessor and enforcement is on
        :raises KeyAlreadyExists: if key exists and overwriting is not allowed
        """
        if isinstance(value, GenericImageDataAccessor):
            acc = value
        else:
            if self.enforce_accessors:
                raise NoAccessorsFoundError(f'Pipeline trace expects data accessor type')
            else:
                acc = InMemoryDataAccessor(value)
        if not self.allow_overwrite and key in self.keys():
            raise KeyAlreadyExists(f'key {key} already exists in pipeline trace')
        self.timer.__setitem__(key, self.tfunc() - self.last_time)
        self.last_time = self.tfunc()
        return super().__setitem__(key, acc)

    def append(self, tr, skip_first=True):
        """
        Return a new trace combining this trace with the items of another,
        carrying over the other trace's recorded step times.
        :param tr: trace whose items are appended after this trace's items
        :param skip_first: if True, omit tr's first item (typically its input)
        :raises KeyAlreadyExists: on key collision when overwriting is not allowed
        :return: new PipelineTrace; neither operand is modified
        """
        new_tr = self.copy()
        for k, v in tr.items():
            if skip_first and v == tr.first:
                continue
            dt = tr.timer[k]
            # avoid clashing with this trace's own 'input' step
            if k == 'input':
                k = 'appended_input'
            if not self.allow_overwrite and k in self.keys():
                raise KeyAlreadyExists(f'Trying to append trace with key {k} that already exists')
            new_tr.__setitem__(k, v)
            # preserve the originally measured duration, not the insertion delta
            new_tr.timer.__setitem__(k, dt)
        new_tr.last_time = self.tfunc()
        return new_tr

    @property
    def times(self):
        """
        Return an ordered dictionary of incremental times for each item that is appended
        """
        return {k: self.timer[k] for k in self.keys()}

    @property
    def first(self):
        """
        Return first item
        :return:
        """
        return list(self.values())[0]

    @property
    def last(self):
        """
        Return most recently appended item
        :return:
        """
        return list(self.values())[-1]

    def accessors(self, skip_first=True, skip_last=True) -> dict:
        """
        Return subset ordered dictionary that guarantees items are accessors
        :param skip_first: if True, exclude first item in trace
        :param skip_last: if True, exclude last item in trace
        :return: dictionary of accessors that meet input criteria
        """
        res = OrderedDict()
        keys = list(self.keys())  # hoisted; was rebuilt twice per iteration
        for k, v in self.items():
            if not isinstance(v, GenericImageDataAccessor):
                continue
            if skip_first and k == keys[0]:
                continue
            if skip_last and k == keys[-1]:
                continue
            res[k] = v
        return res

    def write_interm(
            self,
            where: Path,
            prefix: str = 'interm',
            skip_first=True,
            skip_last=True,
            debug=False
    ) -> List[Path]:
        """
        Write accessor data to TIF files under specified path
        :param where: directory in which to write image files
        :param prefix: (optional) file prefix
        :param skip_first: if True, do not write first item in trace
        :param skip_last: if True, do not write last item in trace
        :param debug: if True, report destination filepaths but do not write files
        :return: list of destination filepaths
        """
        paths = []
        accessors = self.accessors(skip_first=skip_first, skip_last=skip_last)
        for i, item in enumerate(accessors.items()):
            k, v = item
            fp = where / f'{prefix}_{i:02d}_{k}.tif'
            paths.append(fp)
            if not debug:
                v.write(fp)
        return paths


class Error(Exception):
    """Base class for exceptions raised by this module."""
    pass

class IncompatibleModelsError(Error):
    """Raised when models supplied to a pipeline cannot be used together."""
    pass

class KeyAlreadyExists(Error):
    """Raised when a PipelineTrace step key would overwrite an existing one."""
    pass

class NoAccessorsFoundError(Error):
    """Raised when no valid accessor is available where one is required."""
    pass

class UnexpectedPipelineReturnError(Error):
    """Raised when a pipeline function returns an unsupported value."""
    pass