import logging
import os

from pathlib import Path
from pydantic import BaseModel
from time import strftime, localtime
from typing import Union

import pandas as pd

import model_server.conf.defaults
from model_server.base.models import Model

logger = logging.getLogger(__name__)


class Singleton(type):
    """
    Metaclass that gives every class using it a single shared instance.

    The first call to such a class constructs the instance normally. Later calls return the cached instance
    without re-running __init__, so their arguments are ignored.
    """
    # One registry shared by all classes using this metaclass, keyed by class
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance

class CsvTable(object):
    """
    Append-only CSV file that accumulates DataFrames written to it.

    The first append (over)writes the file including a header row; later appends add rows without a header.
    """

    def __init__(self, fpath: Path):
        """
        :param fpath: path of the CSV file; nothing is written until the first append
        """
        self.path = fpath
        self.empty = True  # True until the first append has written the header

    def append(self, coords: dict, data: pd.DataFrame) -> bool:
        """
        Write the rows of data to the CSV file, prefixed with one column per entry in coords.
        :param coords: column name -> value; each entry becomes a leading column, in dict order
            (a scalar value is broadcast by pandas to every row)
        :param data: DataFrame of rows to write; it is not modified
        :return: True if successful
        :raises TypeError: if data is not a pandas DataFrame
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError(f'Expected a pandas DataFrame, got {type(data).__name__}')
        # Work on a copy: DataFrame.insert mutates in place. Inserting into the caller's frame would alter it,
        # and appending that same frame again would fail because the coordinate columns already exist.
        rows = data.copy()
        for pos, (col, val) in enumerate(coords.items()):
            rows.insert(pos, col, val)
        rows.to_csv(self.path, index=False, mode='w' if self.empty else 'a', header=self.empty)
        self.empty = False
        return True

class Session(object, metaclass=Singleton):
    """
    Singleton class for a server session that persists data between API calls
    """

    log_format = '%(asctime)s - %(levelname)s - %(message)s'

    def __init__(self, root: Union[str, None] = None):
        """
        Create this session's directories, route root-logger output to the session log file,
        and reset the model and table registries.
        :param root: top-level directory for session data; model_server.conf.defaults.root if None
        """
        print('Initializing session')
        self.models = {}  # model_id : {'object': model instance, 'params': its params}
        self.paths = self.make_paths(root)

        self.logfile = self.paths['logs'] / 'session.log'
        # force=True removes and closes any handlers already on the root logger, so restart()
        # redirects logging to the new session's log file
        logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format)

        self.log_info('Initialized session')
        self.tables = {}  # table name : CsvTable

    def write_to_table(self, name: str, coords: dict, data: pd.DataFrame) -> bool:
        """
        Write data to a named data table, initializing if it does not yet exist.
        :param name: name of the table to persist through session
        :param coords: dictionary of coordinates to associate with all rows in this method call
        :param data: DataFrame containing data
        :return: True if successful
        :raises CouldNotCreateTable: if the table object could not be created
        :raises CouldNotAppendToTable: if writing the data failed
        """
        try:
            table = self.tables.get(name)
            if table is None:
                table = CsvTable(self.paths['tables'] / (name + '.csv'))
                self.tables[name] = table
        except Exception as e:
            raise CouldNotCreateTable(f'Unable to create table named {name}') from e

        try:
            table.append(coords, data)
        except Exception as e:
            raise CouldNotAppendToTable(f'Unable to append data to table named {name}') from e
        return True

    def get_paths(self):
        return self.paths

    def set_data_directory(self, key: str, path: str):
        """
        Point an existing session path entry at a different, already-existing directory.
        :param key: one of the keys of self.paths
        :param path: directory to use instead
        :raises InvalidPathError: if key is unknown or path does not exist
        """
        if key not in self.paths:
            raise InvalidPathError(f'No such path {key}')
        new_path = Path(path)
        if not new_path.exists():
            raise InvalidPathError(f'Could not find {path}')
        self.paths[key] = new_path

    @staticmethod
    def make_paths(root: Union[str, None] = None) -> dict:
        """
        Set paths where images, logs, etc. are located in this session
        :param root: absolute path to top-level directory
        :return: dictionary of session paths
        :raises CouldNotCreateDirectory: if any session subdirectory cannot be created
        """
        if root is None:
            root_path = Path(model_server.conf.defaults.root)
        else:
            root_path = Path(root)
        sid = Session.create_session_id(root_path)
        paths = {'root': root_path}
        for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
            pa = root_path / sid / model_server.conf.defaults.subdirectories[pk]
            paths[pk] = pa
            try:
                pa.mkdir(parents=True, exist_ok=True)
            except Exception as e:
                raise CouldNotCreateDirectory(f'Could not create directory: {pa}') from e
        return paths

    @staticmethod
    def create_session_id(look_where: Path) -> str:
        """
        Autogenerate a session ID of the form YYYYMMDD-NNNN (local date). NNNN is the lowest index, counting
        from 0000, for which no file or directory of that name exists yet under look_where.
        """
        yyyymmdd = strftime('%Y%m%d', localtime())
        idx = 0
        while os.path.exists(look_where / f'{yyyymmdd}-{idx:04d}'):
            idx += 1
        return f'{yyyymmdd}-{idx:04d}'

    def get_log_data(self) -> list:
        """
        Parse the session log file into records, newest line first.
        :return: list of dicts keyed 'datatime', 'level', 'message'; lines not in log_format
            (e.g. continuation lines) yield dicts with fewer keys
        """
        keys = ('datatime', 'level', 'message')  # 'datatime' (sic) kept as-is: callers may rely on this key
        log = []
        with open(self.logfile, 'r') as fh:
            for line in fh:
                # maxsplit=2 keeps a message that itself contains ' - ' intact
                log.append(dict(zip(keys, line.strip().split(' - ', 2))))
        log.reverse()
        return log

    def log_info(self, msg):
        logger.info(msg)

    def log_warning(self, msg):
        logger.warning(msg)

    def log_error(self, msg):
        logger.error(msg)

    def load_model(self, ModelClass: Model, params: Union[BaseModel, None]) -> str:
        """
        Load an instance of a given model class and attach to this session's model registry
        :param ModelClass: subclass of Model
        :param params: optional parameters that are passed to the model's constructor
        :return: model_id of loaded model, of the form <ClassName>_<NN>
        """
        mi = ModelClass(params=params)
        # NOTE(review): assert is skipped under -O; CouldNotInstantiateModelError may be the intended exception
        assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
        ii = 0

        def mid(i):
            return f'{ModelClass.__name__}_{i:02d}'

        # First unused index for this class name
        while mid(ii) in self.models:
            ii += 1

        key = mid(ii)
        self.models[key] = {
            'object': mi,
            'params': mi.params
        }
        self.log_info(f'Loaded model {key}')
        return key

    def describe_loaded_models(self) -> dict:
        """
        :return: model_id -> {'class': model class name, 'params': model params} for every loaded model
        """
        return {
            k: {
                'class': v['object'].__class__.__name__,
                'params': v['params'],
            }
            for k, v in self.models.items()
        }

    def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> Union[str, None]:
        """
        Returns model_id of first model where key and value match with .params field, or None
        :param key: parameter name to look up in each loaded model's params
        :param value: value to compare against
        :param is_path: uses platform-independent path comparison if True
        :return: model_id of the first match in registry insertion order, or None
        """
        for mid, det in self.describe_loaded_models().items():
            params = det.get('params')
            if params is None:  # model loaded without params cannot match
                continue
            pv = params.get(key)  # presumably params is dict-like; verify against Model
            if is_path:
                # Path(None) raises TypeError, so a model lacking this param simply doesn't match
                if pv is not None and Path(pv) == Path(value):
                    return mid
            elif pv == value:
                return mid
        return None

    def restart(self, **kwargs):
        self.__init__(**kwargs)

class Error(Exception):
    """Base class for the exceptions defined in this module."""


class InferenceRecordError(Error):
    """Error concerning an inference record (not raised within this module)."""


class CouldNotInstantiateModelError(Error):
    """Error indicating a model could not be instantiated (not raised within this module)."""


class CouldNotCreateDirectory(Error):
    """Raised when a session subdirectory cannot be created."""


class CouldNotCreateTable(Error):
    """Raised when a named session data table cannot be created."""


class CouldNotAppendToTable(Error):
    """Raised when data cannot be appended to a session data table."""


class InvalidPathError(Error):
    """Raised for an unknown session path key or a nonexistent directory."""