from pathlib import Path
import shutil

import h5py
import numpy as np
import vigra

from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel


class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """
    Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented
    as time-series images where each frame contains only one object.
    """

    def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> (MonoPatchStack, dict):
        """
        Run ilastik batch object classification on a stack of single-object patches.

        :param input_acc: single-channel stack of raw-data patches
        :param segmentation_acc: stack of binary masks matching input_acc
        :return: tuple of (MonoPatchStack built from the exported object map, arranged Y x X x Z with
            input_acc.nz frames; dict of inference metadata)
        """
        assert segmentation_acc.is_mask()
        assert input_acc.chroma == 1

        # tag both arrays with their TCZYX axis order so ilastik can interpret the layout
        tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
        tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')

        # a single lane: raw data plus its matching segmentation
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)

        assert len(obmaps) == 1, f'ilastik generated {len(obmaps)} object maps; expected exactly one'

        # for some reason ilastik scrambles these axes to Z(1)YX(1), i.e. (nz, 1, h, w, 1)
        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)

        # drop the two singleton axes, then move Z last: (nz, h, w) -> (h, w, nz)
        yxz = np.moveaxis(
            obmaps[0][:, 0, :, :, 0],
            [1, 2, 0],
            [0, 1, 2]
        )

        assert yxz.shape[0:2] == input_acc.hw
        assert yxz.shape[2] == input_acc.nz
        return MonoPatchStack(data=yxz), {'success': True}


def _delete_if_present(h5, key: str) -> None:
    """Delete the dataset or group at path ``key`` in the open HDF5 file/group ``h5``, if it exists."""
    if key in h5:
        del h5[key]


def _labels_from_stack(label_stack, label_names: list) -> list:
    """
    Validate every patch of a label stack and return the label ID found in each.

    Each patch must contain unlabeled background (0) and exactly one non-zero label ID no greater than
    len(label_names).
    :param label_stack: stack of label patches, read through its ``count`` attribute and ``iat()`` method
    :param label_names: list of label names
    :return: list with one label ID per patch, in patch order
    """
    labels = []
    for ii in range(label_stack.count):
        unique = np.unique(label_stack.iat(ii))  # sorted distinct values in this patch
        assert len(unique) >= 2, 'Label image does not contain both unlabeled background and a label'
        assert len(unique) == 2, 'Label image contains more than one non-zero value'
        assert unique[0] == 0, 'Label image does not contain unlabeled background'
        assert unique[-1] <= len(label_names), \
            f'Label ID {unique[-1]} exceeds number of label names: {len(label_names)}'
        labels.append(unique[-1])
    return labels


def _write_input_info(h5, group: str, acc) -> None:
    """
    Point one ilastik input-data info group at the file behind ``acc`` and record its nickname and ZYX shape.

    :param h5: open, writable project file
    :param group: path of the info group, e.g. 'Input Data/infos/lane0000/Raw Data'
    :param acc: patch-stack accessor exposing ``fpath`` and ``shape_dict``
    """
    # only the file name is stored, so the path is relative; presumably ilastik resolves it against the project
    # file's directory, which means the image files should sit next to the target project -- TODO confirm
    del h5[f'{group}/filePath']
    h5[f'{group}/filePath'] = acc.fpath.name
    assert not Path(h5[f'{group}/filePath'][()].decode()).is_absolute()
    assert h5[f'{group}/filePath'][()] == acc.fpath.name.encode()
    assert h5[f'{group}/location'][()] == 'FileSystem'.encode()

    # set input nickname
    del h5[f'{group}/nickname']
    h5[f'{group}/nickname'] = acc.fpath.stem

    # set input shape
    del h5[f'{group}/shape']
    shape_zyx = [acc.shape_dict[ax] for ax in ['Z', 'Y', 'X']]
    h5[f'{group}/shape'] = np.array(shape_zyx)


def generate_ilastik_object_classifier(
        template_ilp: Path,
        target_ilp: Path,
        raw_stack: MonoPatchStackFromFile,
        mask_stack: MonoPatchStackFromFile,
        label_stack: MonoPatchStackFromFile,
        label_names: list,
        lane: int = 0,
) -> Path:
    """
    Starting with a template project file, transfer input data and labels to a new project file.

    Label patches are validated before anything is written, so invalid labels do not leave a partially
    populated project file at target_ilp.

    :param template_ilp: path to existing ilastik object classifier to use as a template
    :param target_ilp: path to new classifier
    :param raw_stack: stack of patches containing raw data
    :param mask_stack: stack of patches containing object masks
    :param label_stack: stack of patches containing object labels; each patch must contain background (0) and
        exactly one non-zero label ID no greater than len(label_names)
    :param label_names: list of label names
    :param lane: ilastik lane identifier
    :return: path to generated object classifier
    :raises AssertionError: if stack shapes differ or a label patch is invalid
    """
    assert mask_stack.shape == raw_stack.shape
    assert label_stack.shape == raw_stack.shape

    # validate and collect labels before touching the filesystem
    labels = _labels_from_stack(label_stack, label_names)

    new_ilp = shutil.copy(template_ilp, target_ilp)

    accessors = {
        'Raw Data': raw_stack,
        'Segmentation Image': mask_stack,
    }

    # write to new project file
    with h5py.File(new_ilp, 'r+') as h5:

        for gk, acc in accessors.items():
            _write_input_info(h5, f'Input Data/infos/lane{lane:04d}/{gk}', acc)

        # replace label names
        k = 'ObjectClassification/LabelNames'
        _delete_if_present(h5, k)
        h5.create_dataset(k, data=np.array(label_names).astype('O'))

        k = 'ObjectClassification/MaxNumObj'
        _delete_if_present(h5, k)
        h5[k] = len(label_names) - 1

        # NOTE(review): presumably selects the applet ilastik shows on opening the project -- confirm
        _delete_if_present(h5, 'currentApplet')
        h5['currentApplet'] = 1

        # one dataset per patch index, each holding [0., label ID]
        k = f'ObjectClassification/LabelInputs/{lane:04d}'
        _delete_if_present(h5, k)
        lag = h5.create_group(k)
        for zi, la in enumerate(labels):
            lag[f'{zi}'] = np.array([0., float(la)])

        # delete features and classifier carried over from the template
        _delete_if_present(h5, f'ObjectExtraction/RegionFeatures/{lane:04d}')
        _delete_if_present(h5, 'ObjectClassification/ClassifierForests')

    return Path(new_ilp)