import builtins
from pathlib import Path

import numpy as np

from model_server.accessors import generate_file_accessor, InMemoryDataAccessor

class MonoPatchStack(InMemoryDataAccessor):
    """
    A stack of n monochrome 2D patches, stored as a 4d array with axes Y, X, C, Z where the C axis has length 1
    and the Z axis indexes the patches.
    """

    def __init__(self, data):
        """
        A sequence of n monochrome images of the same size
        :param data: one of
            - np.ndarray of dimensions YXn, stored as YX1n;
            - np.ndarray of dimensions YXCZ (e.g. the data of another patch stack), stored as-is without copying;
            - list of np.ndarrays of size YX (singleton axes are squeezed away), stacked into YX1n.
        :raises InvalidDataForPatchStackError: if data is of an unsupported type or dimensionality
        """

        if isinstance(data, np.ndarray):
            if data.ndim == 3:  # YXn: insert a singleton channel axis -> YX1n
                self._data = np.expand_dims(data, 2)
            elif data.ndim == 4:  # already YXCZ, e.g. another patch stack's data; aliased, not copied
                self._data = data
            else:
                raise InvalidDataForPatchStackError(
                    f'Expected a 3d (YXn) or 4d (YXCZ) array, got {data.ndim} dimensions'
                )
        elif isinstance(data, list):  # list of YX patches
            if len(data) == 0:
                # No patches to take a dtype from, so fall back to uint8 for the empty YXCZ stack
                self._data = np.zeros((0, 0, 0, 0), dtype='uint8')
            elif len(data) == 1:
                # YX -> YX11
                self._data = np.expand_dims(
                    np.array(
                        data[0].squeeze()
                    ),
                    (2, 3)
                )
            else:
                nda = np.array(data).squeeze()  # nYX once singleton axes are removed
                if nda.ndim != 3:
                    raise InvalidDataForPatchStackError(
                        f'List of patches must stack to a 3d (nYX) array after squeezing, got {nda.ndim} dimensions'
                    )
                # nYX -> YXn, then insert a singleton channel axis -> YX1n
                self._data = np.expand_dims(
                    np.moveaxis(
                        nda,
                        [1, 2, 0],
                        [0, 1, 2]),
                    2
                )
        else:
            raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')

    def make_tczyx(self):
        """
        Return the stack as a 5d array with axes T, C, Z, Y, X: patches run along T, and C and Z are singletons.
        :return: np.ndarray of shape (n, 1, 1, Y, X)
        """
        assert self.chroma == 1  # presumably the base accessor's channel count; only single-channel stacks supported
        tyx = np.moveaxis(
            self.data[:, :, 0, :],  # YX(C)Z with C dropped -> YXn
            [2, 0, 1],
            [0, 1, 2]
        )  # -> nYX
        return np.expand_dims(tyx, (1, 2))

    @property
    def count(self):
        """Number of patches, as reported by the base accessor's nz (presumably the length of the Z axis)."""
        return self.nz

    def iat(self, i):
        """Return the i-th patch as a 2d YX array."""
        return self.data[:, :, 0, i]

    def iat_yxcz(self, i):
        """Return the i-th patch as a 4d YXCZ array with singleton C and Z axes."""
        return np.expand_dims(self.iat(i), (2, 3))

    def get_list(self):
        """Return all patches as a list of 2d YX arrays, in stack order."""
        return [self.data[:, :, 0, zi] for zi in range(self.nz)]


class MonoPatchStackFromFile(MonoPatchStack):
    """A MonoPatchStack whose patches are read from an image file on disk."""

    def __init__(self, fpath):
        """
        :param fpath: path to an image file, opened via generate_file_accessor
        :raises FileNotFoundError: this module's FileNotFoundError (an Error subclass) if fpath does not exist
        """
        found = Path(fpath).exists()
        if not found:
            raise FileNotFoundError(f'Could not find {fpath}')
        self.file_acc = generate_file_accessor(fpath)
        # Keep only index 0 along axis 2 (presumably the channel axis of YXCZ data), leaving a 3d YXn array
        first_plane = self.file_acc.data[:, :, 0, :]
        super().__init__(first_plane)

    @property
    def fpath(self):
        """Path of the source file, as reported by the underlying file accessor."""
        source = self.file_acc
        return source.fpath

class Multichannel3dPatchStack(InMemoryDataAccessor):
    """
    A stack of n (generally) multichannel 3D patches, stored as a 5d array with axes P (patch), Y, X, C, Z.
    """

    def __init__(self, data):
        """
        A sequence of n (generally) color 3D images
        :param data: a list of 4d np.ndarrays of size YXCZ. The stack takes the element-wise maximum shape over all
            patches; patches with fewer Z slices are zero-padded at the end of Z. Every patch is cast to the dtype of
            the first one. An empty list yields an empty uint8 stack.
        :raises InvalidDataForPatchStackError: if data is not a list, or any element is not 4-dimensional
        """

        if not isinstance(data, list):
            raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')

        if len(data) == 0:
            # No patches to take a shape or dtype from; mirror MonoPatchStack's empty-list behavior
            self._data = np.zeros((0, 0, 0, 0, 0), dtype='uint8')
            return

        if any(np.ndim(e) != 4 for e in data):
            raise InvalidDataForPatchStackError('Each patch must be a 4d (YXCZ) array')

        max_shape = np.array([e.shape for e in data]).max(axis=0)  # per-axis maximum over all patches
        nda = np.zeros((len(data), *max_shape), dtype=data[0].dtype)
        for i, patch in enumerate(data):
            # Only Z is padded explicitly; Y, X and C are written across their full extent
            nda[i, :, :, :, 0:patch.shape[-1]] = patch
        self._data = nda

    def iat(self, i):
        """Return the i-th patch as a 4d YXCZ array."""
        return self.data[i, :, :, :, :]

    def iat_yxcz(self, i):
        """Return the i-th patch as a 4d YXCZ array (same as iat)."""
        return self.iat(i)

    @property
    def count(self):
        """Number of patches (length of the P axis)."""
        return self.shape_dict['P']

    @property
    def data(self):
        """
        Return data as 5d with axes in order of pos, Y, X, C, Z
        :return: np.ndarray
        """
        return self._data

    @property
    def shape(self):
        """Shape of the 5d PYXCZ data array."""
        return self._data.shape

    @property
    def shape_dict(self):
        """Mapping of axis labels P, Y, X, C, Z to their lengths."""
        return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))

class Error(Exception):
    """Root of the exception hierarchy for this module's patch-stack errors."""

class InvalidDataForPatchStackError(Error):
    """Raised when the data passed to a patch-stack constructor cannot be interpreted as a patch stack."""

class FileNotFoundError(Error, builtins.FileNotFoundError):
    """
    Raised when an expected input file does not exist.

    This name shadows the builtin FileNotFoundError inside this module. Deriving from the builtin as well (an OSError
    subclass) lets callers that catch the standard FileNotFoundError or OSError handle it too, while it still remains
    an Error of this module.
    """