import os
from pathlib import Path

import numpy as np
import vigra

import extensions.ilastik.conf
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel


class IlastikModel(Model):
    """
    Base class for models backed by an ilastik project (*.ilp) file that is run through a headless ilastik shell.

    Subclasses must provide a get_workflow() static method returning the ilastik workflow class that the project
    file is expected to describe; load() checks the loaded project against it.
    """

    def __init__(self, params, autoload=True):
        """
        Resolve and validate the ilastik project file, then defer to Model.__init__.

        :param params: dict of model parameters; must contain 'project_file', either an absolute path or a path
            relative to extensions.ilastik.conf.paths['project_files']. That entry is rewritten in place as a str.
        :param autoload: forwarded to Model.__init__ (presumably triggers load() when True; confirm in Model)
        :raises ParameterExpectedError: if params has no 'project_file' entry
        :raises FileNotFoundError: if the resolved project file does not exist
        """
        # Check for the key before reading it, so a missing entry raises ParameterExpectedError rather than KeyError
        if 'project_file' not in params:
            raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
        self.project_file = Path(params['project_file'])
        params['project_file'] = str(self.project_file)
        if self.project_file.is_absolute():
            pap = self.project_file
        else:
            pap = extensions.ilastik.conf.paths['project_files'] / self.project_file
        self.project_file_abspath = pap
        if not pap.exists():
            raise FileNotFoundError(f'Project file does not exist: {pap}')

        self.shell = None  # populated by load() once the ilastik shell has started
        super().__init__(autoload, params)

    def load(self):
        """
        Start a headless ilastik shell on the project file and verify that it hosts the expected workflow.

        :return: True on success
        :raises ParameterExpectedError: if the project describes a workflow that is not an instance of get_workflow()
        """
        # Imported here rather than at module level, so ilastik is only imported when a model is actually loaded
        from ilastik import app
        from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo

        # Kept on the instance so subclasses' infer() can wrap in-memory arrays without importing ilastik themselves
        self.PreloadedArrayDatasetInfo = PreloadedArrayDatasetInfo

        # Lazyflow thread count and RAM budget (MB) for ilastik; hard-coded here
        os.environ["LAZYFLOW_THREADS"] = "8"
        os.environ["LAZYFLOW_TOTAL_RAM_MB"] = "24000"

        args = app.parse_args([])
        args.headless = True
        args.project = str(self.project_file_abspath)
        shell = app.main(args, init_logging=False)

        expected_workflow = self.get_workflow()
        if not isinstance(shell.workflow, expected_workflow):
            raise ParameterExpectedError(
                f'Ilastik project file {self.project_file} does not describe an instance of '
                f'{expected_workflow.__name__} (found {shell.workflow.__class__.__name__})'
            )
        self.shell = shell

        return True


class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
    """
    Semantic segmentation using an ilastik pixel classification project.
    """
    model_id = 'ilastik_pixel_classification'
    operations = ['segment', ]

    @staticmethod
    def get_workflow():
        """Return the ilastik workflow class that the project file must describe."""
        from ilastik.workflows import PixelClassificationWorkflow
        return PixelClassificationWorkflow

    def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
        """
        Run the pixel classifier on one image.

        :param input_img: image whose data is handed to ilastik with axis tags 'yxcz'
        :return: (InMemoryDataAccessor of the exported pixel map in YXCZ order, with the channel axis indexed by
            pixel class, {'success': True})
        """
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
            }
        ]
        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(pxmaps) == 1, 'ilastik generated more than one pixel map'

        # Reorder the exported [z, h, w, n] array to YXCZ
        yxcz = np.moveaxis(
            pxmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
        """
        Return a binary mask of pixels whose value for one pixel class exceeds a threshold.

        :param img: input image
        :param px_class: index into the channel axis of the pixel map (negative indices count from the end)
        :param px_prob_threshold: pixels strictly greater than this value are True in the mask
        :param kwargs: accepted for interface compatibility; unused
        :return: InMemoryDataAccessor wrapping a boolean YXCZ mask with a single channel
        :raises IndexError: if px_class is out of range for the pixel map's channels
        """
        pxmap, _ = self.infer(img)
        # Index the channel axis with a one-element list so it is kept (Y, X, 1, Z); a scalar index would drop it
        # and leave a 3D (Y, X, Z) array that no longer matches the YXCZ layout used everywhere else here.
        mask = pxmap.data[:, :, [px_class], :] > px_prob_threshold
        return InMemoryDataAccessor(mask)


class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
    """
    Object classification using an ilastik project whose objects are defined by a supplied segmentation mask.
    """
    model_id = 'ilastik_object_classification_from_segmentation'

    @staticmethod
    def get_workflow():
        """Return the ilastik workflow class that the project file must describe."""
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
        return ObjectClassificationWorkflowBinary

    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
        """
        Classify the objects defined by a segmentation mask.

        :param input_img: raw image, handed to ilastik with axis tags 'yxcz'
        :param segmentation_img: mask image (must satisfy is_mask()); boolean data is converted to uint8 0/255
        :return: (InMemoryDataAccessor of the exported object map in YXCZ order, {'success': True})
        """
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        assert segmentation_img.is_mask()
        if segmentation_img.dtype == 'bool':
            # Boolean masks are passed to ilastik as uint8 with foreground at 255
            seg = 255 * segmentation_img.data.astype('uint8')
        else:
            seg = segmentation_img.data
        tagged_seg_data = vigra.taggedView(seg, 'yxcz')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        # Reorder the exported [z, h, w, n] array to YXCZ
        yxcz = np.moveaxis(
            obmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
        """
        Return a map in which each object in mask is assigned a class.

        :param img: input image
        :param mask: segmentation mask defining the objects
        :param kwargs: forwarded to the parent label_instance_class
        :return: InMemoryDataAccessor of the object map in YXCZ order
        """
        # Parent call's return value is discarded; presumably it only validates the inputs (confirm in
        # InstanceSegmentationModel)
        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
        obmap, _ = self.infer(img, mask)
        return obmap


class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
    """
    Object classification using an ilastik project whose objects are derived from pixel prediction maps.
    """
    model_id = 'ilastik_object_classification_from_pixel_predictions'

    @staticmethod
    def get_workflow():
        """Return the ilastik workflow class that the project file must describe."""
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
        return ObjectClassificationWorkflowPrediction

    def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
        """
        Classify objects from a raw image and its pixel prediction maps.

        :param input_img: raw image, handed to ilastik with axis tags 'yxcz'
        :param pxmap_img: image passed to ilastik as 'Prediction Maps', with axis tags 'yxcz'
        :return: (InMemoryDataAccessor of the exported object map in YXCZ order, {'success': True})
        """
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        # Reorder the exported [z, h, w, n] array to YXCZ
        yxcz = np.moveaxis(
            obmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
        """
        Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
        assigned a class.

        :param img: input image
        :param pxmap: map of pixel probabilities
        :param kwargs:
            pixel_classification_channel: channel of pxmap used to segment objects (default 0)
            pixel_classification_threshold: threshold of pxmap used to segment objects (default 0.5)
        :return: InMemoryDataAccessor of the object map in YXCZ order
        :raises InvalidInputImageError: if img and pxmap differ in shape
        """
        if not img.shape == pxmap.shape:
            raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
        # TODO: check that pxmap is in-range
        pxch = kwargs.get('pixel_classification_channel', 0)
        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
        mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
        # NOTE(review): the thresholded boolean mask, not pxmap itself, is what infer() passes to ilastik as
        # 'Prediction Maps'; confirm this is what the ObjectClassificationWorkflowPrediction project expects.
        obmap, _ = self.infer(img, mask)
        return obmap