from math import floor, sqrt
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.stats import moment
from skimage.filters import sobel
from skimage.io import imsave
from skimage.measure import find_contours, shannon_entropy
from tifffile import imwrite

from extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack
from extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.process import pad, rescale, resample_to_8bit

def _make_rgb(zs):
    """
    Embed a YXCZ array with at most three channels into a three-channel YXCZ array.

    Channels not present in ``zs`` are zero-filled; the dtype of ``zs`` is preserved.
    """
    height, width, nchan, depth = zs.shape
    assert nchan <= 3
    rgb = np.zeros((height, width, 3, depth), dtype=zs.dtype)
    rgb[:, :, :nchan, :] = zs
    return rgb

def _focus_metrics():
    """
    Return the available focus metrics, keyed by name.

    Each value is a callable that takes one image plane and returns a scalar score.
    Callers in this module select the z-plane with the largest score.
    """
    def _max_sobel(img):
        return np.max(sobel(img))

    def _rms_sobel(img):
        return sqrt(np.mean(sobel(img) ** 2))

    def _entropy(img):
        return shannon_entropy(img)

    def _second_moment(img):
        # second central moment of all pixel values, i.e. their (population) variance
        return moment(img.flatten(), moment=2)

    return {
        'max_intensity': np.max,
        'stdev': np.std,
        'max_sobel': _max_sobel,
        'rms_sobel': _rms_sobel,
        'entropy': _entropy,
        'moment': _second_moment,
    }

def _write_patch_to_file(where, fname, yxcz):
    """
    Write a single patch array to ``where / fname``; the file extension selects the writer.

    :param where: Path of the output directory; created (with parents) if it does not exist
    :param fname: output file name; extension PNG, TIF or TIFF (case-insensitive)
    :param yxcz: 4D array with axes (Y, X, C, Z).
        PNG output requires dtype uint8, at most 3 channels and a single z-plane.
        TIF output is written as a ZCYX ImageJ hyperstack.
    :return: True on success
    :raises ValueError: if the extension is not supported
    """
    ext = fname.split('.')[-1].upper()
    if ext not in ('PNG', 'TIF', 'TIFF'):
        # validate before touching the filesystem so a bad name leaves no empty directory behind
        raise ValueError(f'Unsupported file extension: {ext}')
    where.mkdir(parents=True, exist_ok=True)

    if ext == 'PNG':
        assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}'
        assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
        assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs'
        nc = yxcz.shape[2]
        if nc == 1:
            outdata = yxcz[:, :, 0, 0]
        elif nc == 2:
            # add a blank blue channel, then drop the singleton z axis so imsave receives (Y, X, 3)
            outdata = _make_rgb(yxcz)[:, :, :, 0]
        else:  # preserve RGB order
            outdata = yxcz[:, :, :, 0]
        imsave(where / fname, outdata, check_contrast=False)
        return True

    # TIF / TIFF: reorder YXCZ -> ZCYX for the ImageJ-compatible writer
    zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3])
    imwrite(where / fname, zcyx, imagej=True)
    return True

def get_patch_masks_from_zmask_meta(
    stack: GenericImageDataAccessor,
    zmask_meta: list,
    pad_to: int = 256,
) -> MonoPatchStack:
    """
    Build one uint8 mask patch, shape (h, w, 1, 1), per object described in zmask_meta.

    (h, w) are the spatial dimensions of stack.data[mi['slice']]. The object's mask, multiplied by 255,
    is written at its relative bounding box; all other pixels are 0.

    :param stack: image data; only the shape of each sliced region is used
    :param zmask_meta: list of dicts with keys 'slice', 'relative_bounding_box' (x0, y0, x1, y1) and 'mask'
        (presumably a boolean array matching the bounding box -- confirm with the producer of zmask_meta)
    :param pad_to: if truthy, each patch is passed through pad(patch, pad_to)
    :return: MonoPatchStack of mask patches, in zmask_meta order
    """
    def _mask_patch(meta):
        region = meta['slice']
        box = meta['relative_bounding_box']
        height, width = stack.data[region].shape[0:2]
        out = np.zeros((height, width, 1, 1), dtype='uint8')
        out[box['y0']: box['y1'], box['x0']: box['x1'], 0, 0] = meta['mask'] * 255
        return pad(out, pad_to) if pad_to else out

    return MonoPatchStack([_mask_patch(meta) for meta in zmask_meta])

def export_patch_masks_from_zstack(
    where: Path,
    stack: GenericImageDataAccessor,
    zmask_meta: list,
    pad_to: int = 256,
    prefix='mask',
    **kwargs
):
    """
    Write one PNG mask patch per object in zmask_meta and return the written file names.

    Files are named '{prefix}-la{label:04d}-zi{zi:04d}.png', using mi['info'].label and mi['info'].zi.

    :param where: output directory (created if missing)
    :param stack: image data used to size each mask patch
    :param zmask_meta: per-object metadata, as consumed by get_patch_masks_from_zmask_meta, plus key 'info'
    :param pad_to: padding size passed through to get_patch_masks_from_zmask_meta
    :param prefix: file name prefix
    :param kwargs: accepted so callers can share option dicts with the other export functions, but ignored.
        They are not forwarded: get_patch_masks_from_zmask_meta takes no extra keyword arguments,
        so forwarding them would raise TypeError.
    :return: list of exported file names, in zmask_meta order
    """
    patches_acc = get_patch_masks_from_zmask_meta(
        stack,
        zmask_meta,
        pad_to=pad_to,
    )
    assert len(zmask_meta) == patches_acc.count

    exported = []
    for i, mi in enumerate(zmask_meta):
        obj = mi['info']
        fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.png'
        _write_patch_to_file(where, fname, patches_acc.iat_yxcz(i))
        exported.append(fname)
    return exported

def get_patches_from_zmask_meta(
        stack: GenericImageDataAccessor,
        zmask_meta: list,
        rescale_clip: float = 0.0,
        pad_to: int = 256,
        make_3d: bool = False,
        focus_metric: str = None,
        **kwargs
) -> MonoPatchStack:
    """
    Extract one image patch per object described in zmask_meta, optionally annotated.

    For each entry, the region stack.data[mi['slice']] (axes Y, X, C, Z) is reduced to a patch:
      - make_3d=True: all z-planes are kept;
      - focus_metric given: for each channel, the z-plane where that metric (a key of _focus_metrics()),
        evaluated inside the object's relative bounding box, is largest;
      - otherwise: the middle z-plane, index floor(nz / 2).

    :param stack: source image data; .data is indexed with each mi['slice'] and .chroma is the expected channel count
    :param zmask_meta: list of dicts with keys 'slice', 'relative_bounding_box' (x0, y0, x1, y1) and 'mask'
    :param rescale_clip: if not None, each patch is passed through rescale(patch, rescale_clip)
    :param pad_to: if truthy, each patch is passed through pad(patch, pad_to)
    :param make_3d: keep the full z-stack instead of a single plane
    :param focus_metric: name of the focus metric used to select the z-plane of 2D patches
    :param kwargs: annotation options
        draw_bounding_box (must be True), bounding_box_channel (default 0, must be < 3; a value > 0 converts
            the patch to 3 channels), bounding_box_linewidth (default 1);
        draw_mask, mask_channel (default 0): zero the object's masked pixels in that channel;
        draw_contour, contour_channel (default 0): draw the object's mask contours in that channel
    :return: MonoPatchStack for 2D patches from single-channel data, otherwise Multichannel3dPatchStack
    """
    patches = []
    for mi in zmask_meta:
        sl = mi['slice']
        rbb = mi['relative_bounding_box']
        x0 = rbb['x0']
        y0 = rbb['y0']
        x1 = rbb['x1']
        y1 = rbb['y1']
        sp_sl = np.s_[y0: y1, x0: x1, :, :]

        patch3d = stack.data[sl]
        ph, pw, pc, pz = patch3d.shape
        subpatch = patch3d[sp_sl]

        if make_3d:
            # copy: stack.data[sl] is a view when sl consists of basic slices, and the annotation steps
            # below write into the patch in place, which would otherwise modify the caller's stack
            patch = patch3d.copy()

        elif focus_metric is not None:
            # 2D patch: per channel, take the z-plane that maximizes the focus metric inside the bounding box
            foc = _focus_metrics()[focus_metric]
            patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
            for ci in range(pc):
                scores = [foc(subpatch[:, :, ci, zi]) for zi in range(pz)]
                patch[:, :, ci, 0] = patch3d[:, :, ci, np.argmax(scores)]

        else:
            # 2D patch from the middle of the z-stack; the list index keeps the z axis (and yields a copy)
            zim = floor(pz / 2)
            patch = patch3d[:, :, :, [zim]]

        assert len(patch.shape) == 4
        assert patch.shape[2] == stack.chroma

        if rescale_clip is not None:
            patch = rescale(patch, rescale_clip)

        if kwargs.get('draw_bounding_box') is True:
            bci = kwargs.get('bounding_box_channel', 0)
            assert bci < 3
            if bci > 0:
                patch = _make_rgb(patch)
            for zi in range(patch.shape[3]):
                patch[:, :, bci, zi] = draw_box_on_patch(
                    patch[:, :, bci, zi],
                    ((x0, y0), (x1, y1)),
                    linewidth=kwargs.get('bounding_box_linewidth', 1)
                )

        draw_mask = kwargs.get('draw_mask')
        draw_contour = kwargs.get('draw_contour')
        if draw_mask or draw_contour:
            # the object's mask placed in full-patch coordinates
            mask = np.zeros(patch.shape[0:2], dtype=bool)
            mask[sp_sl[0:2]] = mi['mask']

        if draw_mask:
            mci = kwargs.get('mask_channel', 0)
            for zi in range(patch.shape[3]):
                patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]

        if draw_contour:
            mci = kwargs.get('contour_channel', 0)
            contours = find_contours(mask)  # identical for every z-plane; compute once
            for zi in range(patch.shape[3]):
                patch[:, :, mci, zi] = draw_contours_on_patch(
                    patch[:, :, mci, zi],
                    contours
                )

        if pad_to:
            patch = pad(patch, pad_to)

        patches.append(patch)

    # stack.chroma equals every patch's original channel count (asserted above); using it instead of the
    # loop variable pc avoids a NameError when zmask_meta is empty
    if not make_3d and stack.chroma == 1:
        return MonoPatchStack(patches)
    return Multichannel3dPatchStack(patches)

def export_patches_from_zstack(
        where: Path,
        stack: GenericImageDataAccessor,
        zmask_meta: list,
        rescale_clip: float = 0.0,
        pad_to: int = 256,
        make_3d: bool = False,
        prefix='patch',
        focus_metric: str = None,
        **kwargs
):
    """
    Extract patches with get_patches_from_zmask_meta and write one file per object.

    Files are named '{prefix}-la{label:04d}-zi{zi:04d}.{ext}', using mi['info'].label and mi['info'].zi.
    The extension is 'tif' when make_3d is set, otherwise 'png'.

    :param where: output directory (created if missing)
    :param stack: source image data
    :param zmask_meta: per-object metadata; besides the keys used by get_patches_from_zmask_meta,
        each entry needs 'info' and 'df_index'
    :param rescale_clip: forwarded to get_patches_from_zmask_meta
    :param pad_to: forwarded to get_patches_from_zmask_meta
    :param make_3d: forwarded to get_patches_from_zmask_meta
    :param prefix: file name prefix
    :param focus_metric: forwarded to get_patches_from_zmask_meta
    :param kwargs: annotation options, forwarded to get_patches_from_zmask_meta
    :return: list of dicts {'df_index': ..., 'patch_filename': ...}, in zmask_meta order
    """
    patches_acc = get_patches_from_zmask_meta(
        stack,
        zmask_meta,
        rescale_clip=rescale_clip,
        pad_to=pad_to,
        make_3d=make_3d,
        focus_metric=focus_metric,
        **kwargs
    )
    assert len(zmask_meta) == patches_acc.count

    ext = 'tif' if make_3d else 'png'
    exported = []
    for i, mi in enumerate(zmask_meta):
        patch = patches_acc.iat_yxcz(i)
        obj = mi['info']
        idx = mi['df_index']
        fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}'

        # compare dtypes by equality, not identity; 16-bit data is converted to 8-bit
        # (the PNG writer only accepts uint8)
        if patch.dtype == np.dtype('uint16'):
            _write_patch_to_file(where, fname, resample_to_8bit(patch))
        else:
            _write_patch_to_file(where, fname, patch)

        exported.append({
            'df_index': idx,
            'patch_filename': fname,
        })
    return exported

def export_3d_patches_with_focus_metrics(
        where: Path,
        stack: GenericImageDataAccessor,
        zmask_meta: list,
        rescale_clip: float = 0.0,
        pad_to: int = 256,
        prefix='patch',
        **kwargs
):
    """
    Export 3D patches as multi-level z-stacks, along with a CSV of various focus metrics for each z-position.

    For each object, a TIF '{prefix}-la{label:04d}-zi{zi:04d}.tif' holds the 8-bit z-stack patch. A CSV with
    the same stem holds one row per z-index and one column per focus metric, evaluated inside the object's
    relative bounding box.

    :param where: output directory (created if missing)
    :param stack: monochromatic z-stack image data
    :param zmask_meta: per-object metadata with keys 'info', 'slice', 'relative_bounding_box' and 'df_index'
    :param rescale_clip: if not None, each patch is passed through rescale(patch, rescale_clip)
    :param pad_to: if truthy, each patch is passed through pad(patch, pad_to)
    :param prefix: file name prefix
    :param kwargs:
        annotate_focus_metric: name of the focus metric used to pick the z-plane on which the bounding box is drawn
    :return: list of dicts {'df_index', 'patch_filename', 'focus_metrics_filename'}, in zmask_meta order
    """
    assert stack.chroma == 1, 'Expecting monochromatic image data'
    assert stack.nz > 1, 'Expecting z-stack'

    metrics = _focus_metrics()
    ak = kwargs.get('annotate_focus_metric')

    def get_zstack_focus_metrics(zs):
        # NOTE(review): planes are passed as (h, w, 1) arrays here, whereas get_patches_from_zmask_meta
        # passes 2D planes to the same metrics; sobel-based values may differ between the two -- confirm intent
        dd = {}
        for zi in range(zs.shape[3]):
            spf = zs[:, :, :, zi]
            dd[zi] = {k: f(spf) for k, f in metrics.items()}
        return dd

    exported = []
    for mi in zmask_meta:
        obj = mi['info']
        sl = mi['slice']
        rbb = mi['relative_bounding_box']
        idx = mi['df_index']

        # copy: stack.data[sl] is a view when sl consists of basic slices, and the box is drawn in place
        # below, which would otherwise modify the caller's stack (e.g. when rescale_clip is None)
        patch = stack.data[sl].copy()

        assert len(patch.shape) == 4
        assert patch.shape[2] == stack.chroma

        if rescale_clip is not None:
            patch = rescale(patch, rescale_clip)

        # unpack relative bounding box and define subset of patch data
        x0 = rbb['x0']
        y0 = rbb['y0']
        x1 = rbb['x1']
        y1 = rbb['y1']
        sp_sl = np.s_[y0: y1, x0: x1, :, :]
        subpatch = patch[sp_sl]

        # compute focus metrics for all z-levels; rows are z-indices, columns are metric names
        me_df = pd.DataFrame(get_zstack_focus_metrics(subpatch)).T

        # draw the bounding box only on the z-plane that maximizes the requested metric
        if ak and ak in me_df.columns:
            zi_foc = me_df[ak].idxmax()
            patch[:, :, 0, zi_foc] = draw_box_on_patch(
                patch[:, :, 0, zi_foc],
                ((x0, y0), (x1, y1)),
            )

        if pad_to:
            patch = pad(patch, pad_to)

        fstem = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}'
        # writing the TIF first also creates `where`, which to_csv below relies on
        _write_patch_to_file(where, fstem + '.tif', resample_to_8bit(patch))
        me_df.to_csv(where / (fstem + '.csv'))
        exported.append({
            'df_index': idx,
            'patch_filename': fstem + '.tif',
            'focus_metrics_filename': fstem + '.csv',
        })

    return exported

def export_multichannel_patches_from_zstack(
    where: Path,
    stack: GenericImageDataAccessor,
    zmask_meta: list,
    ch_rgb_overlay: tuple = None,
    overlay_gain: tuple = (1.0, 1.0, 1.0),
    ch_white: int = None,
    **kwargs
):
    """
    Export RGB patches where each patch is assignable to a channel of the input stack.

    The composite is built from a base and optional overlays. The base is either channel ch_white copied
    into all three RGB channels, or a copy of the full stack. Overlays are then added with saturation at the
    dtype's maximum. stack.data itself is never modified.

    :param ch_rgb_overlay: tuple of integers (R, G, B) that assign a stack channel index to an RGB channel;
        an entry may be None to leave that RGB channel unchanged
    :param overlay_gain: optional, tuple of float (R, G, B) multipliers that can be used to balance relative brightness
    :param ch_white: int, index of stack channel that becomes grayscale signal in export patches (0 is valid)
    :param kwargs: forwarded to export_patches_from_zstack
    :return: the return value of export_patches_from_zstack for the composite data
    """
    def _safe_add(a, g, b):
        # a + g * b, computed in a wider type and clipped to a's integer range
        assert a.dtype == b.dtype
        assert a.shape == b.shape
        assert g >= 0.0

        return np.clip(
            a.astype('uint32') + g * b.astype('uint32'),
            0,
            np.iinfo(a.dtype).max
        ).astype(a.dtype)

    idata = stack.data
    if ch_white is not None:  # 'is not None' so that channel 0 can be selected
        assert ch_white < stack.chroma
        mdata = idata[:, :, [ch_white, ch_white, ch_white], :]  # list index yields a new array
    else:
        # copy so the overlays below neither modify stack.data nor read back channels already overlaid
        mdata = idata.copy()

    if ch_rgb_overlay:
        assert len(ch_rgb_overlay) == 3
        assert len(overlay_gain) == 3
        for ii, ci in enumerate(ch_rgb_overlay):
            if ci is None:
                continue
            assert isinstance(ci, int)
            assert ci < stack.chroma
            mdata[:, :, ii, :] = _safe_add(
                mdata[:, :, ii, :],
                overlay_gain[ii],
                idata[:, :, ci, :]
            )

    mstack = InMemoryDataAccessor(mdata)
    return export_patches_from_zstack(
        where, mstack, zmask_meta, **kwargs
    )