import json
from concurrent import futures

import numpy as np
import h5py
import pandas as pd

from heimdall import view, to_source
from elf.skeleton import skeletonize
from scripts.attributes.cilia_attributes import (compute_centerline,
                                                 get_bb, load_seg,
                                                 make_indexable)


# NOTE: the centerline paths computed so far don't look great;
# the TEASAR parameters probably need some tuning to improve them.
def view_centerline(raw, obj, path, compare_skeleton=False):
    """Open a viewer showing a cilium, its mask and a centerline path.

    Args:
        raw: raw data cut out around the object.
        obj: mask of the object, same shape as ``raw``; it is shown as uint32.
        path: centerline coordinates, converted with ``make_indexable`` and
            drawn as a uint32 volume with value 1 on the path.
        compare_skeleton: if True, additionally compute ``skeletonize(obj)``
            and show its coordinates as a further layer.
    """
    layers = [raw, obj.astype('uint32')]

    centerline = np.zeros(obj.shape, dtype='uint32')
    centerline[make_indexable(path)] = 1
    layers.append(centerline)

    if compare_skeleton:
        skel_coords, _ = skeletonize(obj)
        skeleton = np.zeros(obj.shape, dtype='uint32')
        skeleton[make_indexable(skel_coords)] = 1
        layers.append(skeleton)

    view(*layers)


def check_lens(cilia_ids=None, compare_skeleton=False,
               skeleton_path='precomputed_cilia.json'):
    """Visually inspect precomputed cilia centerlines one cilium at a time.

    For every requested id, the raw data and the cilium mask are cut out
    using the bounding box from ``get_bb`` and shown together with the
    precomputed centerline (see ``view_centerline``).

    Args:
        cilia_ids: label ids to inspect; defaults to ``range(len(table))``.
            Ids 0, 1 and 2 are always skipped.
        compare_skeleton: also show an ``elf`` skeletonization for comparison.
        skeleton_path: JSON file with the precomputed centerlines, a list
            indexed by label id with ``None`` for missing entries (as written
            by ``precompute``).
    """
    path = '../data/0.5.1/segmentations/sbem-6dpf-1-whole-segmented-cilia-labels.h5'
    path_raw = '../data/rawdata/sbem-6dpf-1-whole-raw.h5'
    table_path = '../data/0.5.1/tables/sbem-6dpf-1-whole-segmented-cilia-labels/default.csv'
    # the table is passed on to get_bb with its default RangeIndex
    table = pd.read_csv(table_path, sep='\t')

    with open(skeleton_path) as f:
        skeletons = json.load(f)

    if cilia_ids is None:
        cilia_ids = range(len(table))

    resolution = [.025, .01, .01]
    with h5py.File(path, 'r') as f, h5py.File(path_raw, 'r') as fr:
        ds = f['t00000/s00/0/cells']
        dsr = fr['t00000/s00/0/cells']

        for cid in cilia_ids:
            # NOTE(review): ids 0-2 are never processed, presumably because
            # they are not cilia (background etc.) -- TODO confirm
            if cid in (0, 1, 2):
                continue

            print(cid)
            # ids outside the precomputed list (e.g. a stale JSON file) have
            # no centerline; treat them like a missing (None) entry
            obj_path = skeletons[cid] if 0 <= cid < len(skeletons) else None
            if obj_path is None:
                print("Skipping cilia", cid)
                continue
            print(len(obj_path))

            bb = get_bb(table, cid, resolution)
            raw = dsr[bb]
            obj = ds[bb] == cid
            view_centerline(raw, obj, obj_path, compare_skeleton)


def precompute(n_threads=16, out_path='precomputed_cilia.json'):
    """Compute the centerline of every cilium and store them as JSON.

    The result written to ``out_path`` is a list indexed by label id.
    Entries for the skipped ids 0, 1, 2 and for empty segments are ``None``.
    ``check_lens`` reads this file back.

    Args:
        n_threads: number of worker threads computing centerlines.
        out_path: JSON file the list of centerlines is written to.
    """
    path = '../data/0.5.1/segmentations/sbem-6dpf-1-whole-segmented-cilia-labels.h5'
    table_path = '../data/0.5.1/tables/sbem-6dpf-1-whole-segmented-cilia-labels/default.csv'
    # the table is passed on to load_seg with its default RangeIndex
    table = pd.read_csv(table_path, sep='\t')

    resolution = [.025, .01, .01]
    # open read-only explicitly: before h5py 3.0 the default mode was 'a',
    # which opens the segmentation for writing (and creates it if missing)
    with h5py.File(path, 'r') as f:
        ds = f['t00000/s00/0/cells']

        def precomp(cid):
            # NOTE(review): ids 0-2 are presumably not cilia -- TODO confirm
            if cid in (0, 1, 2):
                return None
            print(cid)
            obj = load_seg(ds, table, cid, resolution)
            if obj.sum() == 0:
                return None
            # resolution is scaled by 1000 for compute_centerline
            # (presumably micrometer -> nanometer; TODO confirm unit)
            return compute_centerline(obj, [res * 1000 for res in resolution])

        n_cilia = len(table)
        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(precomp, cid) for cid in range(n_cilia)]
            # collect in submission order so results[cid] belongs to cid
            results = [t.result() for t in tasks]

    # NOTE(review): assumes compute_centerline returns JSON-serializable data
    # (e.g. nested lists of coordinates) -- TODO confirm
    with open(out_path, 'w') as f:
        json.dump(results, f)


def grid_search(label_id=11,
                penalty_scales=(1000, 2500, 5000, 10000),
                penalty_exponents=(2, 4, 8, 16),
                n_threads=16,
                out_path='grid_search.json'):
    """Compute centerlines of one cilium over a grid of penalty parameters.

    Every combination of ``penalty_scales`` and ``penalty_exponents`` is
    passed to ``compute_centerline``. The results are written to ``out_path``
    as a list of dicts with the keys ``penalty_scale``, ``penalty_exponent``
    and ``path``. The scale is the outer loop and the exponent the inner one.
    ``eval_gridsearch`` reads this file back.

    Args:
        label_id: id of the cilium whose centerlines are computed.
        penalty_scales: values tried for ``penalty_scale``.
        penalty_exponents: values tried for ``penalty_exponent``.
        n_threads: number of worker threads.
        out_path: JSON file the results are written to.
    """
    path = '../data/0.5.1/segmentations/sbem-6dpf-1-whole-segmented-cilia-labels.h5'
    table_path = '../data/0.5.1/tables/sbem-6dpf-1-whole-segmented-cilia-labels/default.csv'
    # the table is passed on to load_seg with its default RangeIndex
    table = pd.read_csv(table_path, sep='\t')

    resolution = [.025, .01, .01]
    # open read-only explicitly: before h5py 3.0 the default mode was 'a',
    # which opens the segmentation for writing (and creates it if missing)
    with h5py.File(path, 'r') as f:
        ds = f['t00000/s00/0/cells']

        def precomp(cid, penalty_scale, penalty_exponent):
            print("scale:", penalty_scale, "exponent:", penalty_exponent)
            obj = load_seg(ds, table, cid, resolution)
            centerline = compute_centerline(obj, [res * 1000 for res in resolution],
                                            penalty_scale=penalty_scale,
                                            penalty_exponent=penalty_exponent)
            return {'penalty_scale': penalty_scale,
                    'penalty_exponent': penalty_exponent,
                    'path': centerline}

        with futures.ThreadPoolExecutor(n_threads) as tp:
            tasks = [tp.submit(precomp, label_id, penalty_scale, penalty_exponent)
                     for penalty_scale in penalty_scales
                     for penalty_exponent in penalty_exponents]
            results = [t.result() for t in tasks]

    with open(out_path, 'w') as f:
        json.dump(results, f)


def eval_gridsearch(label_id=11, results_path='grid_search.json'):
    """View the grid-search centerlines of one cilium next to its data.

    Reads the results written by ``grid_search`` and opens a viewer with:
    the raw data, the binary mask of ``label_id`` (layer ``mask``), and one
    layer per parameter setting, named ``<penalty_scale>_<penalty_exponent>``.

    Args:
        label_id: id of the cilium; must match the id used in ``grid_search``.
        results_path: JSON file written by ``grid_search``.
    """
    with open(results_path) as f:
        results = json.load(f)

    path_raw = '../data/rawdata/sbem-6dpf-1-whole-raw.h5'
    path = '../data/0.5.1/segmentations/sbem-6dpf-1-whole-segmented-cilia-labels.h5'
    table_path = '../data/0.5.1/tables/sbem-6dpf-1-whole-segmented-cilia-labels/default.csv'
    # the table is passed on to get_bb with its default RangeIndex
    table = pd.read_csv(table_path, sep='\t')

    resolution = [.025, .01, .01]
    with h5py.File(path, 'r') as f, h5py.File(path_raw, 'r') as fr:
        ds = f['t00000/s00/0/cells']
        dsr = fr['t00000/s00/0/cells']
        bb = get_bb(table, label_id, resolution)

        # slicing an h5py dataset reads the data into memory, so the files
        # can be closed before the (blocking) viewer is opened
        raw = dsr[bb]
        obj = (ds[bb] == label_id).astype('uint32')

    sources = [to_source(raw, name='raw'), to_source(obj, name='mask')]
    for res in results:
        line = np.zeros_like(obj)
        line[make_indexable(res['path'])] = 1
        name = '%i_%i' % (res['penalty_scale'], res['penalty_exponent'])
        sources.append(to_source(line, name=name))

    view(*sources)


if __name__ == '__main__':
    # Other entry points (enable as needed):
    #   precompute()       -> writes precomputed_cilia.json
    #   grid_search()      -> writes grid_search.json
    #   eval_gridsearch()  -> views the results stored in grid_search.json
    check_lens(cilia_ids=[11], compare_skeleton=True)