from pathlib import Path
import re
from time import localtime, strftime
from typing import Dict

import numpy as np
import pandas as pd
from skimage.filters import gaussian, sobel

from extensions.ilastik.models import IlastikPixelClassifierModel
from extensions.chaeo.products import export_3d_patches_with_focus_metrics, export_patches_from_zstack
from extensions.chaeo.zmask import build_zmask_from_object_mask
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.workflows import Timer


def export_patch_focus_metrics_from_multichannel_zstack(
        input_zstack_path: str,
        ilastik_project_file: str,
        pxmap_threshold: float,
        pixel_class: int,
        zmask_channel: int,
        patches_channel: int,
        where_output: str,
        mask_type: str = 'boxes',
        zmask_filters: Dict = None,
        zmask_expand_box_by: int = None,
        annotate_focus_metric=None,
        **kwargs,
) -> Dict:
    """
    Segment objects in a multichannel z-stack and export 2D and 3D patches around them.

    Steps: load the z-stack; max-project channel 0 along its last axis; run an ilastik
    pixel classifier on that projection; threshold the probability map into an object
    mask; build a z-mask against ``zmask_channel``; then export patches of
    ``patches_channel`` to ``<where_output>/3d_patches`` and ``<where_output>/2d_patches``.
    Elapsed time of each step is recorded with a Timer.

    :param input_zstack_path: path to the input image; it must have more than one z-slice
    :param ilastik_project_file: path to the ilastik pixel-classifier project file
    :param pxmap_threshold: probability-map values strictly greater than this are foreground
    :param pixel_class: channel of the thresholded probability map used as the object mask
    :param zmask_channel: channel of the input stack passed to z-mask construction
    :param patches_channel: channel of the input stack from which patches are exported
    :param where_output: root output directory; patch subdirectories are placed under it
    :param mask_type: forwarded to build_zmask_from_object_mask as ``mask_type``
    :param zmask_filters: forwarded as ``filters``, e.g. ``{'area': (1e3, 1e8)}``
    :param zmask_expand_box_by: forwarded as ``expand_box_by``.
        NOTE(review): the batch driver passes a 2-tuple such as ``(128, 3)``,
        which disagrees with the ``int`` annotation.
    :param annotate_focus_metric: forwarded to the 3D patch export, e.g. ``'max_sobel'``
    :param kwargs: accepted for call-site compatibility and ignored
    :return: dict with keys 'pixel_model_id', 'input_filepath', 'number_of_objects',
        'success', 'timer_results', 'dataframe' and 'interm'
    :raises ValueError: if the input has only a single z-slice
    """
    ti = Timer()
    stack = generate_file_accessor(Path(input_zstack_path))
    fstem = Path(input_zstack_path).stem
    ti.click('file_input')
    # Explicit check instead of ``assert`` so the validation survives ``python -O``.
    if stack.nz <= 1:
        raise ValueError('Expecting z-stack')

    # MIP of channel 0 over the last axis (presumably z -- TODO confirm accessor axis order);
    # keepdims=True keeps the projected axis with length 1.
    mip = InMemoryDataAccessor(
        stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True)
    )
    px_model = IlastikPixelClassifierModel(
        params={'project_file': Path(ilastik_project_file)}
    )
    pxmap, _ = px_model.infer(mip)
    ti.click('infer_pixel_probability')

    # Boolean foreground mask: probability strictly above the threshold.
    obmask = InMemoryDataAccessor(pxmap.data > pxmap_threshold)
    ti.click('threshold_pixel_mask')

    # Build the z-mask; only its metadata, object table and intermediates are used below.
    _zmask, zmask_meta, df, interm = build_zmask_from_object_mask(
        obmask.get_one_channel_data(pixel_class),
        stack.get_one_channel_data(zmask_channel),
        mask_type=mask_type,
        filters=zmask_filters,
        expand_box_by=zmask_expand_box_by,
    )
    ti.click('generate_zmasks')

    # A fresh channel accessor is requested for each export, as before, in case the
    # export helpers modify the data they are given.
    export_3d_patches_with_focus_metrics(
        Path(where_output) / '3d_patches',
        stack.get_one_channel_data(patches_channel),
        zmask_meta,
        prefix=fstem,
        rescale_clip=0.0,
        make_3d=True,
        annotate_focus_metric=annotate_focus_metric,
    )
    ti.click('export_3d_patches')

    export_patches_from_zstack(
        Path(where_output) / '2d_patches',
        stack.get_one_channel_data(patches_channel),
        zmask_meta,
        prefix=fstem,
        draw_bounding_box=True,
        rescale_clip=0.0,
        focus_metric='max_sobel',
        make_3d=False,
    )
    ti.click('export_2d_patches')

    return {
        'pixel_model_id': px_model.model_id,
        'input_filepath': input_zstack_path,
        'number_of_objects': len(zmask_meta),
        'success': True,
        'timer_results': ti.events,
        'dataframe': df,
        'interm': interm,
    }

if __name__ == '__main__':
    # Batch driver: run the patch export over a folder of CZI z-stacks and collect the
    # per-object tables, step timings and workflow parameters into CSV files.
    where_czi = Path(
        'c:/Users/rhodes/projects/proj0004-marine-photoactivation/data/exp0038/AutoMic/20230906-163415/Selection'
    )

    where_output_root = Path(
        'c:/Users/rhodes/projects/proj0011-plankton-seg/exp0009/output'
    )

    # Use the first unused batch-output-<date>-<NNNN> folder so that earlier runs from
    # the same day are never overwritten.
    yyyymmdd = strftime('%Y%m%d', localtime())
    idx = 0
    while (where_output_root / f'batch-output-{yyyymmdd}-{idx:04d}').exists():
        idx += 1
    where_output = where_output_root / f'batch-output-{yyyymmdd}-{idx:04d}'
    # Summary CSVs are written directly into where_output, so make sure it exists.
    where_output.mkdir(parents=True, exist_ok=True)

    csv_args = {'mode': 'w', 'header': True}  # when creating file
    px_ilp = Path.home() / 'model_server' / 'ilastik' / 'AF405-bodies_boundaries.ilp'

    # Raw string: '\d' in a plain literal is an invalid escape sequence (SyntaxWarning
    # on Python 3.12+). The pattern is compiled once, outside the loop.
    stem_pattern = re.compile(r'Selection--W([\d]+)--P([\d]+)-T([\d]+)')

    for ff in sorted(where_czi.iterdir()):
        # NOTE(review): hard-coded single-file filter, presumably left over from debugging;
        # remove it to process the whole folder.
        if ff.stem != 'Selection--W0000--P0009-T0001':
            continue

        print(ff)
        if ff.suffix.upper() != '.CZI':
            continue
        ma = stem_pattern.match(ff.stem)
        if ma is None:  # file name does not follow the W/P/T naming scheme
            continue
        if int(ma.group(2)) > 10:  # skip second half of set
            continue

        export_kwargs = {
            # iterdir() already yields where_czi/<name>; joining where_czi again would
            # duplicate the directory whenever where_czi is a relative path.
            'input_zstack_path': str(ff),
            'ilastik_project_file': str(px_ilp),
            'pxmap_threshold': 0.25,
            'pixel_class': 0,
            'zmask_channel': 0,
            'patches_channel': 4,
            'where_output': str(where_output),
            'mask_type': 'boxes',
            'zmask_filters': {'area': (1e3, 1e8)},
            'zmask_expand_box_by': (128, 3),
            'annotate_focus_metric': 'max_sobel'
        }

        result = export_patch_focus_metrics_from_multichannel_zstack(**export_kwargs)

        # parse and record results
        df = result['dataframe']
        df['filename'] = ff.name
        df.to_csv(where_output / 'df_objects.csv', **csv_args)
        pd.DataFrame(result['timer_results'], index=[0]).to_csv(where_output / 'timer_results.csv', **csv_args)
        pd.json_normalize(export_kwargs).to_csv(where_output / 'workflow_params.csv', **csv_args)
        csv_args = {'mode': 'a', 'header': False}  # append to CSV from here on

        # Write each intermediate array in result['interm'] to <where_output>/<key>/<stem>.tif
        for key, data in result['interm'].items():
            (where_output / key).mkdir(parents=True, exist_ok=True)
            write_accessor_data_to_file(
                where_output / key / f'{ff.stem}.tif',
                InMemoryDataAccessor(data)
            )