import itertools
from math import sqrt, floor
from pathlib import Path
from typing import List, Union
from uuid import uuid4

import numpy as np
import pandas as pd
from pydantic import BaseModel
from scipy.stats import moment
from skimage.filters import sobel

from skimage import draw
from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops, regionprops_table, shannon_entropy
from skimage.morphology import binary_dilation, disk

from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from .models import InstanceSegmentationModel
from .process import get_safe_contours, pad, rescale, resample_to_8bit, make_rgb
from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
from .accessors import generate_file_accessor, PatchStack
from .process import mask_largest_object


class PatchParams(BaseModel):
    """
    Parameters of one patch export product; RoiSet.run_exports() passes these fields as keyword arguments to
    RoiSet.export_patches(), which forwards them to RoiSet.get_patches()
    """
    draw_bounding_box: bool = False  # draw each ROI's bounding box; get_patches() only does so for expanded patches
    draw_contour: bool = False  # draw each ROI's segmentation contour onto the patch
    draw_mask: bool = False  # multiply the patch by the inverted segmentation mask (see get_patches)
    rescale_clip: float = 0.001  # passed to process.rescale() for each patch
    focus_metric: str = 'max_sobel'  # key of focus_metrics(); picks the z-slice of 2D expanded patches
    rgb_overlay_channels: Union[List[Union[int, None]], None] = None  # three raw channels (or None) mapped to R, G, B
    rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0]  # weight of each of the three RGB overlay channels
    pad_to: Union[int, None] = 256  # passed to process.pad(); get_patches() pads expanded patches only
    expanded: bool = False  # use each ROI's expanded bounding box instead of its tight one


class AnnotatedZStackParams(BaseModel):
    """
    Parameters of the annotated z-stack export; passed as keyword arguments to annotators.draw_boxes_on_3d_image()
    via RoiSet.export_annotated_zstack()
    """
    draw_label: bool = False  # presumably draws each ROI's label; semantics defined in draw_boxes_on_3d_image
    channel: Union[int, None] = None  # channel used by draw_boxes_on_3d_image (assumed; confirm in annotators)


class RoiFilterRange(BaseModel):
    """Open interval used by filter_df(): a ROI is kept only if its value is strictly between min and max"""
    min: float  # exclusive lower bound; filter_df() asserts it is >= 0
    max: float  # exclusive upper bound


class RoiFilter(BaseModel):
    """Range filters applied by filter_df(); each field name is the DataFrame column it filters on"""
    area: Union[RoiFilterRange, None] = None  # range of the 'area' column (from skimage regionprops)


class RoiSetMetaParams(BaseModel):
    """Parameters that influence how the ROIs of a RoiSet are defined"""
    filters: Union[RoiFilter, None] = None  # optional ROI filters, applied by filter_df() when building from object IDs
    expand_box_by: List[int] = [128, 0]  # [pixels in Y and X, slices in Z] by which each bounding box is expanded


class RoiSetExportParams(BaseModel):
    """
    Selects the products written by RoiSet.run_exports(); each field name is also the name of the product's
    subdirectory and of its key in the returned record
    """
    patches_3d: Union[PatchParams, None] = None  # 3D patches; skipped if None
    annotated_patches_2d: Union[PatchParams, None] = None  # 2D patches with annotations; skipped if None
    patches_2d: Union[PatchParams, None] = None  # 2D patches, whose paths are also added to the RoiSet DataFrame
    annotated_zstacks: Union[AnnotatedZStackParams, None] = None  # whole z-stack with ROIs annotated
    object_classes: bool = False  # flag for exporting object class maps of classify_by() results
    derived_channels: bool = False  # flag for exporting derived channels computed in classify_by()


def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor:
    """
    Convert binary segmentation mask into either a 2D or 3D object identities map
    :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions
    :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity projection if False
    :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
    :return: uint16 object identities map, where 0 is background
    """
    if allow_3d and connect_3d:
        # connectivity=3: voxels touching by face, edge, or corner belong to the same object
        nda_la = label(
            acc_seg_mask.data_xyz,
            connectivity=3,
        ).astype('uint16')
        return InMemoryDataAccessor(np.expand_dims(nda_la, 2))  # insert a singleton channel axis
    elif allow_3d and not connect_3d:
        nla = 0  # highest label assigned so far; offsets each slice's labels so they stay unique across z
        la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16')
        for zi in range(0, acc_seg_mask.nz):
            la_2d = label(
                acc_seg_mask.data_xyz[:, :, zi],
                connectivity=2,
            ).astype('uint16')
            la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla
            # An empty slice has max() == 0; keeping the running maximum prevents the offset from resetting, which
            # would otherwise let later slices reuse labels already assigned in earlier slices
            nla = max(nla, int(la_2d.max()))
            la_3d[:, :, 0, zi] = la_2d
        return InMemoryDataAccessor(la_3d)
    else:
        return InMemoryDataAccessor(
            label(
                acc_seg_mask.data_xyz.max(axis=-1),
                connectivity=1,
            ).astype('uint16')
        )


def focus_metrics():
    """
    Return the available focus metrics keyed by name, in a fixed order.

    Each metric maps a 2D array to a scalar score.  RoiSet.get_patches() keeps, per channel, the z-slice with the
    highest score.
    """
    def _max_sobel(x):
        return np.max(sobel(x))

    def _rms_sobel(x):
        return sqrt(np.mean(sobel(x) ** 2))

    return dict(
        max_intensity=lambda x: np.max(x),
        stdev=lambda x: np.std(x),
        max_sobel=_max_sobel,
        rms_sobel=_rms_sobel,
        entropy=lambda x: shannon_entropy(x),
        moment=lambda x: moment(x.flatten(), moment=2),
    )


def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
    """
    Return the rows of a RoiSet DataFrame that pass a set of range filters
    :param df: DataFrame with a 'label' column plus one column per filtered property (currently only 'area')
    :param filters: (optional) RoiFilter; each range that is set keeps rows strictly between its min and max
    :return: subset of df; all rows with label > 0 if filters is None
    """
    query_str = 'label > 0'  # always true for object IDs; base clause onto which filter clauses are appended
    if filters is not None:  # parse filters
        for k, val in filters.dict(exclude_unset=True).items():
            if val is None:  # range explicitly set to None: nothing to filter on
                continue
            # a one-element tuple: ('area') would be a plain string and accept any substring of 'area'
            assert k in ('area',), f'Unsupported ROI filter: {k}'
            vmin = val['min']
            vmax = val['max']
            assert vmin >= 0
            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
    return df.loc[df.query(query_str).index, :]


def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
    """
    If passed a single DataFrame, return the subset whose bounding boxes overlap each other within the same z-slice.
    If passed two DataFrames, return the subset where a ROI in the first overlaps a ROI in the second.  May return
    duplicate entries where a ROI overlaps with multiple neighbors.  ROIs are referenced by index label, so each
    DataFrame's index is assumed to be unique.
    :param df1: DataFrame with bounding-box columns x0, x1, y0, y1 (upper bounds exclusive) and z-index zi
    :param df2: (optional) second DataFrame with the same columns
    :return DataFrame describing subset of overlapping ROIs, with rows copied from df1 and added columns:
        bbox_overlaps_with: index label of the ROI that overlaps (in df2 if given, else in df1)
        bbox_intersec: pixel area of intersecting region
    """

    def _compare(r0, r1):
        olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)
        oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0)
        olz = (r0.zi == r1.zi)
        return olx and oly and olz

    def _intersec(r0, r1):
        # area of the rectangle shared by two boxes already known to overlap
        dx = min(r0.x1, r1.x1) - max(r0.x0, r1.x0)
        dy = min(r0.y1, r1.y1) - max(r0.y0, r1.y0)
        return dx * dy

    first = []
    second = []
    intersec = []

    # materialize rows once; each namedtuple carries its index label as .Index
    rows1 = list(df1.itertuples())

    if df2 is not None:
        for r1, r2 in itertools.product(rows1, list(df2.itertuples())):
            if _compare(r1, r2):
                first.append(r1.Index)
                second.append(r2.Index)
                intersec.append(_intersec(r1, r2))
    else:
        for r1, r2 in itertools.combinations(rows1, 2):
            if _compare(r1, r2):
                isc = _intersec(r1, r2)
                # report the overlap from both sides
                first.extend([r1.Index, r2.Index])
                second.extend([r2.Index, r1.Index])
                intersec.extend([isc, isc])

    sdf = df1.loc[first].copy()  # label-based, so it is correct for non-contiguous indexes
    sdf['bbox_overlaps_with'] = second
    sdf['bbox_intersec'] = intersec
    return sdf

# TODO: option to quantify overlap e.g. IOU
def filter_df_overlap_seg(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return subset of DataFrame whose bounding boxes overlap, flagging pairs whose segmentations also overlap.
    :param df: RoiSet DataFrame with bounding-box columns (x0, x1, y0, y1, zi) and a binary_mask column holding each
        ROI's mask cropped to its bounding box
    :return: output of filter_df_overlap_bbox(df) with an added boolean column seg_overlaps, True where the two ROIs'
        masks share at least one pixel
    """

    dfbb = filter_df_overlap_bbox(df)

    def _overlap_seg(r):
        roi1 = df.loc[r.name]
        roi2 = df.loc[r.bbox_overlaps_with]
        # paint both masks into the box enclosing both ROIs; pixels covered twice are overlaps
        ex0 = min(roi1.x0, roi2.x0, roi1.x1, roi2.x1)
        ew = max(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ex0
        ey0 = min(roi1.y0, roi2.y0, roi1.y1, roi2.y1)
        eh = max(roi1.y0, roi2.y0, roi1.y1, roi2.y1) - ey0
        emask = np.zeros((eh, ew), dtype='uint8')
        sl1 = np.s_[(roi1.y0 - ey0): (roi1.y1 - ey0), (roi1.x0 - ex0): (roi1.x1 - ex0)]
        sl2 = np.s_[(roi2.y0 - ey0): (roi2.y1 - ey0), (roi2.x0 - ex0): (roi2.x1 - ex0)]
        emask[sl1] = roi1.binary_mask
        emask[sl2] = emask[sl2] + roi2.binary_mask
        return np.any(emask > 1)

    # Build the column from a list: DataFrame.apply(axis=1) returns a DataFrame (not a Series) when there are no
    # overlapping boxes, and that cannot be assigned to a single column
    dfbb['seg_overlaps'] = np.array([_overlap_seg(r) for _, r in dfbb.iterrows()], dtype=bool)
    return dfbb

def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
    """
    Build a DataFrame with one row per object in an object identities map, holding summary statistics, the
    object's z-position, numpy slices (see df_insert_slices), and its binary mask
    :param acc_raw: accessor to raw image data
    :param acc_obj_ids: accessor to map of object IDs; single z-position (2D) or one per z-slice of acc_raw (3D)
    :param expand_box_by: pair (xy, z): pixels by which to expand each bounding box in Y and X, and slices by which
        to expand it in Z, without exceeding the image boundary
    :return: pd.DataFrame
    """
    if acc_obj_ids.nz == 1:
        # 2D map: each object's z-position is deprojected from the raw data, as the mean over the object's pixels of
        # the z-index where channel 0 peaks
        props = regionprops_table(
            acc_obj_ids.data_xy,
            intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
            properties=('label', 'area', 'intensity_mean', 'bbox'),
        )
        bbox_names = {'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}
        df = pd.DataFrame(props).rename(columns=bbox_names)
        df['zi'] = df['intensity_mean'].round().astype('int')
    else:
        # 3D map: each object's z-position is the slice holding most of its pixels
        props = regionprops_table(
            acc_obj_ids.data_xyz,
            properties=('label', 'area', 'bbox'),
        )
        bbox_names = {
            'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
        }
        df = pd.DataFrame(props).rename(columns=bbox_names)

        def _densest_z(obj_label):
            return (acc_obj_ids.data == obj_label).sum(axis=(0, 1, 2)).argmax()

        df['zi'] = df['label'].apply(_densest_z)

    df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by)

    def _tight_mask(roi):
        # projection of the object's pixels, cropped to its bounding box
        obj_acc = InMemoryDataAccessor(acc_obj_ids.data == roi.label)
        yxhw = (roi.y0, roi.x0, (roi.y1 - roi.y0), (roi.x1 - roi.x0))
        return obj_acc.get_mono(0, mip=True).crop_hw(yxhw).data_xy

    df['binary_mask'] = df.apply(_tight_mask, axis=1, result_type='reduce')
    return df


def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
    """
    Add size, expanded-bounding-box, and numpy-slice columns to a DataFrame of ROIs (modifies df in place)
    :param df: DataFrame with columns y0, y1, x0, x1 (upper bounds exclusive) and z-index zi
    :param sd: shape dict of the raw image, with keys 'Y', 'X', and 'Z'
    :param expand_box_by: pair (xy, z): pixels by which to expand boxes in Y and X, and slices by which to expand in Z
    :return: df, with added columns h, w, ebb_* (expanded box; ebb_z0/ebb_z1 inclusive), rel_* (tight box relative
        to expanded box), slice, expanded_slice, and relative_slice
    """
    h = sd['Y']
    w = sd['X']
    nz = sd['Z']

    df['h'] = df['y1'] - df['y0']
    df['w'] = df['x1'] - df['x0']
    ebxy, ebz = expand_box_by
    # Y and X upper bounds are exclusive, so they are clipped to the image size
    df['ebb_y0'] = (df.y0 - ebxy).clip(lower=0)
    df['ebb_y1'] = (df.y1 + ebxy).clip(upper=h)
    df['ebb_x0'] = (df.x0 - ebxy).clip(lower=0)
    df['ebb_x1'] = (df.x1 + ebxy).clip(upper=w)
    # Z bounds are inclusive (see ebb_nz and expanded_slice), so the upper bound is clipped to the last slice
    df['ebb_z0'] = (df.zi - ebz).clip(lower=0)
    df['ebb_z1'] = (df.zi + ebz).clip(upper=nz - 1)
    df['ebb_h'] = df['ebb_y1'] - df['ebb_y0']
    df['ebb_w'] = df['ebb_x1'] - df['ebb_x0']
    df['ebb_nz'] = df['ebb_z1'] - df['ebb_z0'] + 1

    # compute relative bounding boxes
    df['rel_y0'] = df.y0 - df.ebb_y0
    df['rel_y1'] = df.y1 - df.ebb_y0
    df['rel_x0'] = df.x0 - df.ebb_x0
    df['rel_x1'] = df.x1 - df.ebb_x0

    assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0']))
    assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0']))

    # slices index YXCZ arrays: tight box in its own z-slice, expanded box over its z-range, and tight box
    # relative to the expanded patch
    df['slice'] = df.apply(
        lambda r:
        np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)],
        axis=1,
        result_type='reduce',
    )
    df['expanded_slice'] = df.apply(
        lambda r:
        np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1],
        axis=1,
        result_type='reduce',
    )
    df['relative_slice'] = df.apply(
        lambda r:
        np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :],
        axis=1,
        result_type='reduce',
    )
    return df


def safe_add(a, g, b):
    """
    Return a + g * b, computed in a wider type and saturated to the range of a's integer dtype.
    :param a: integer array
    :param g: non-negative gain applied to b
    :param b: integer array with the same dtype and shape as a
    :return: array with a's dtype
    """
    assert a.dtype == b.dtype
    assert a.shape == b.shape
    assert g >= 0.0

    widened = a.astype('uint32') + g * b.astype('uint32')
    ceiling = np.iinfo(a.dtype).max
    saturated = np.clip(widened, 0, ceiling)
    return saturated.astype(a.dtype)

def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
    """
    Rebuild an object identities map from a RoiSet DataFrame
    :param df: DataFrame with columns label, y0, y1, x0, x1, zi, and binary_mask (each object's mask cropped to its
        bounding box)
    :param sd: shape dict of the raw image, with keys 'Y', 'X', and 'Z'
    :return: accessor to a uint16 map of shape (Y, X, 1, Z) holding each object's label at its pixels in z-slice zi,
        and 0 elsewhere; where objects overlap, the later row in df wins
    :raises MissingSegmentationError: if df has no binary_mask column
    """
    if 'binary_mask' not in df.columns:
        raise MissingSegmentationError('RoiSet dataframe does not contain segmentation')
    id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16')

    for r in df.itertuples():
        sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1]
        mask = np.expand_dims(r.binary_mask, (2, 3)).astype(bool)
        # Write the label rather than adding it: summing labels of overlapping objects would produce IDs belonging
        # to neither object (or to an unrelated one).  id_mask[sl] is a view, so this writes into id_mask.
        id_mask[sl][mask] = r.label

    return InMemoryDataAccessor(id_mask)


class RoiSet(object):

    # TODO: __init__ to take bounding boxes e.g. from obj det model; flag if overlaps are allowed
    def __init__(
            self,
            acc_raw: GenericImageDataAccessor,
            df: pd.DataFrame,
            params: RoiSetMetaParams = RoiSetMetaParams(),
    ):
        """
        A set of regions of interest (ROIs), referenced by their positions and contours in the YXCZ space of stack
        acc_raw.  The RoiSet holds their internal state, which export methods turn into patches, maps, and other
        products.
        :param acc_raw: accessor to the source image, generally a multichannel z-stack
        :param df: DataFrame with one row per ROI, containing at minimum bounding box and segmentation mask information
        :param params: optional arguments that influence the definition and representation of ROIs
        """
        self.acc_raw = acc_raw
        self.params = params
        self.accs_derived = []  # derived-channel PatchStacks, appended to by classify_by()

        self._df = df
        self.count = len(df)  # number of ROIs when the RoiSet was created

    def __iter__(self):
        """Iterate over ROIs as namedtuples named 'Roi', one per DataFrame row (the row index is in .Index)"""
        rows = self._df.itertuples(name='Roi')
        return rows

    @staticmethod
    def from_object_ids(
            acc_raw: GenericImageDataAccessor,
            acc_obj_ids: GenericImageDataAccessor,
            params: RoiSetMetaParams = RoiSetMetaParams(),
    ):
        """
        Create a RoiSet from an object identities map
        :param acc_raw: accessor to the source image, generally a multichannel z-stack
        :param acc_obj_ids: accessor to a single-channel object identities map (2D, or 3D with one z-slice per slice of
            acc_raw), where each pixel's intensity labels its membership in a connected object
        :param params: optional arguments; params.expand_box_by sizes the expanded bounding boxes, and
            params.filters selects which objects are kept
        :return: RoiSet
        """
        assert acc_obj_ids.chroma == 1

        all_objects = make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by)
        kept = filter_df(all_objects, params.filters)
        return RoiSet(acc_raw, kept, params)


    # TODO: add a generator for the object detection case
    @staticmethod
    def from_bounding_boxes(
        acc_raw: GenericImageDataAccessor,
        yxhw_list: List,
        params: RoiSetMetaParams = RoiSetMetaParams(),
        zi: Union[int, None] = None,
    ):
        """
        Create a RoiSet whose objects are rectangular bounding boxes, e.g. from an object detection model
        :param acc_raw: accessor to the source image, generally a multichannel z-stack
        :param yxhw_list: list of (y, x, h, w) bounding boxes, in pixels
        :param params: optional arguments that influence the definition and representation of ROIs
        :param zi: (optional) z-index assigned to every box; defaults to the middle slice of acc_raw
        :return: RoiSet whose binary masks fill each bounding box
        """
        df = pd.DataFrame(
            [
                {
                    'y0': yxhw[0],
                    'y1': yxhw[0] + yxhw[2],
                    'x0': yxhw[1],
                    'x1': yxhw[1] + yxhw[3],
                } for yxhw in yxhw_list
            ],
            columns=['y0', 'y1', 'x0', 'x1'],  # keeps the columns when yxhw_list is empty
        )
        # Other RoiSet methods need label, zi, slices, and binary_mask columns.  IDs start at 1 because 0 is
        # background in object identities maps.
        df['label'] = np.arange(1, len(df) + 1)
        df['zi'] = acc_raw.shape_dict['Z'] // 2 if zi is None else zi
        df = df_insert_slices(df, acc_raw.shape_dict, params.expand_box_by)
        df['binary_mask'] = pd.Series(
            [np.ones((int(h), int(w)), dtype=bool) for h, w in zip(df['h'], df['w'])],
            index=df.index,
            dtype=object,
        )
        return RoiSet(acc_raw, df, params)


    @staticmethod
    def from_binary_mask(
            acc_raw: GenericImageDataAccessor,
            acc_seg: GenericImageDataAccessor,
            allow_3d=False,
            connect_3d=True,
            params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create a RoiSet from a binary segmentation mask (either 2D or 3D)
        :param acc_raw: accessor to the source image, generally a multichannel z-stack
        :param acc_seg: accessor to a binary segmentation mask (mono) of either two or three dimensions
        :param allow_3d: use a 3D map if True; use a 2D map of the mask's maximum intensity projection if False
        :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
        :param params: optional arguments that influence the definition and representation of ROIs
        """
        obj_ids = get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d)
        return RoiSet.from_object_ids(acc_raw, obj_ids, params)

    @staticmethod
    def from_polygons_2d(
            acc_raw,
            polygons: List[np.ndarray],
            params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create a RoiSet where objects are defined from a list of polygon coordinates
        :param acc_raw: accessor to the source image, generally a multichannel z-stack
        :param polygons: list of (variable x 2) np.ndarrays describing (x, y) polygon vertex coordinates; vertices may
            lie outside the image, in which case the polygon is clipped to it
        :param params: optional arguments that influence the definition and representation of ROIs
        """
        mask = np.zeros(acc_raw.get_mono(0, mip=True).shape, dtype=bool)
        for p in polygons:
            # shape= clips the filled pixels to the image; without it, out-of-bounds vertices raise IndexError
            rr, cc = draw.polygon(p[:, 1], p[:, 0], shape=mask.shape[0:2])
            mask[rr, cc] = True
        return RoiSet.from_binary_mask(
            acc_raw,
            InMemoryDataAccessor(mask),
            allow_3d=False,
            connect_3d=False,
            params=params,
        )

    def get_df(self) -> pd.DataFrame:
        """Return the RoiSet's DataFrame itself (not a copy), one row per ROI"""
        df = self._df
        return df

    def get_slices(self) -> pd.Series:
        """Return each ROI's numpy slice of its tight bounding box in the raw image"""
        df = self.get_df()
        return df['slice']

    def add_df_col(self, name, se: pd.Series) -> None:
        """Insert or overwrite column name of the RoiSet DataFrame; a Series is aligned on the DataFrame's index"""
        df = self._df
        df[name] = se

    def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack:  # padded, un-annotated 2d patches
        """
        Return ROI patches as a PatchStack
        :param channels: raw channels to include; a single channel is passed to get_patches() as white_channel
        :param kwargs: passed through to get_patches()
        """
        if channels and len(channels) == 1:
            channel_args = {'white_channel': channels[0]}
        else:
            channel_args = {'channels': channels}
        patches_df = self.get_patches(**channel_args, **kwargs)
        return PatchStack(list(patches_df.patch))

    def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> str:
        """
        Write the raw z-stack with ROI boxes drawn on it as a TIF file
        :param where: Path of the directory in which to write the file
        :param prefix: filename without extension
        :param kwargs: passed to annotators.draw_boxes_on_3d_image()
        :return: filename of the written file
        """
        annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs))
        fname = prefix + '.tif'
        write_accessor_data_to_file(where / fname, annotated)
        return fname

    def get_zmask(self, mask_type='boxes') -> np.ndarray:
        """
        Return a boolean mask of same dimensionality as raw data, marking where ROIs sit in the z-stack

        :param mask_type: if 'boxes', zmask is True in each object's complete bounding box at its z-position zi;
            if 'contours', zmask is True at each object pixel of self.acc_obj_ids, in the z-slice given by that
            object's zi
        :return: np.ndarray of dtype bool with shape self.acc_raw.shape

        NOTE(review): self.acc_obj_ids is not assigned in __init__; presumably it is defined elsewhere in the class
        (not visible here) -- confirm.  The 'contours' branch uses np.expand_dims(idxs, (2, 3)), which implies that
        acc_obj_ids.data is 2D (YX) -- confirm against its actual shape.
        """

        assert mask_type in ('contours', 'boxes')
        zi_st = np.zeros(self.acc_raw.shape, dtype='bool')
        lamap = self.acc_obj_ids.data

        # make an object map where label is replaced by focus position in stack and background is -1
        lut = np.zeros(lamap.max() + 1) - 1
        df = self.get_df()
        lut[df.label] = df.zi

        if mask_type == 'contours':
            # shift by +1 so the int cast is taken on non-negative values, then shift back; background becomes -1
            zi_map = (lut[lamap] + 1.0).astype('int')
            idxs = np.array(zi_map) - 1
            # set True at index idxs along the z axis for every pixel
            np.put_along_axis(
                zi_st,
                np.expand_dims(idxs, (2, 3)),
                1,
                axis=3
            )

            # background pixels (index -1) were written into the final z-slice; clear them there
            zi_st[:, :, :, -1][lamap == 0] = 0

        elif mask_type == 'boxes':
            for roi in self:
                zi_st[roi.slice] = True

        return zi_st


    def classify_by(
            self, name: str, channels: list[int],
            object_classification_model: InstanceSegmentationModel,
            derived_channel_functions: list[callable] = None
    ):
        """
        Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
        specified inputs through an instance segmentation classifier.  Optionally derive additional inputs for object
        classification by passing a raw input channel through one or more functions.

        :param name: name of column to insert (the column is named 'classify_by_' + name)
        :param channels: list of nc raw input channels to send to classifier
        :param object_classification_model: InstanceSegmentation model object
        :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
            that return a single-channel PatchStack accessor of the same shape
        :return: None
        :raises DerivedChannelError: if a derived channel has the wrong shape or is not uint8/uint16
        """

        raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
        if derived_channel_functions is not None:
            mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)]
            # Derived channels from this call only.  self.accs_derived accumulates across calls (run_exports writes
            # them all out), so building the classifier input from it would include earlier calls' channels.
            accs_derived = []
            for fcn in derived_channel_functions:
                der = fcn(raw_acc)  # returns patch stack
                if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']:
                    raise DerivedChannelError(
                        f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}'
                    )
                accs_derived.append(der)
            self.accs_derived.extend(accs_derived)

            # combine raw and derived channels along the channel axis
            input_acc = PatchStack(
                np.concatenate(
                    [*mono_data, *[acc.data for acc in accs_derived]],
                    axis=raw_acc._ga('C')
                )
            )

        else:
            input_acc = raw_acc

        # do this on a patch basis, i.e. only one object per frame
        obmap_patches = object_classification_model.label_patch_stack(
            input_acc,
            self.get_patch_masks_acc(expanded=False, pad_to=None)
        )

        colname = 'classify_by_' + name
        self._df[colname] = pd.Series(dtype='Int64')  # nullable ints; all NA until filled below

        for i, roi in enumerate(self):
            # class of the ROI is the highest label value in the largest object of its patch
            oc = np.unique(
                mask_largest_object(
                    obmap_patches.iat(i).data
                )
            )[-1]
            self._df.loc[roi.Index, colname] = oc

    def get_object_class_map(self, name: str) -> InMemoryDataAccessor:
        """
        For a given classification result, return a map where object IDs are replaced by each object's class
        :param name: name of the classification result, same as passed to RoiSet.classify_by()
        :return: accessor of object class map, with the shape and dtype of self.acc_obj_ids; background and
            unclassified (NA) objects are 0
        """
        colname = ('classify_by_' + name)
        assert colname in self._df.columns
        # Read the object ID map once.  self.acc_obj_ids is not assigned in the visible code and may be computed on
        # each access, so it should not be re-read per ROI.
        obj_ids = self.acc_obj_ids
        lamap = obj_ids.data
        om = np.zeros(obj_ids.shape, obj_ids.dtype)

        for obj_label, obj_class in zip(self._df['label'], self._df[colname]):
            if pd.isna(obj_class):  # unclassified ROI: NA cannot be stored in an integer map, so leave it at 0
                continue
            om[lamap == obj_label] = obj_class

        return InMemoryDataAccessor(om)

    def export_dataframe(self, csv_path: Path) -> str:
        """
        Write the RoiSet DataFrame to CSV, without the columns holding numpy slices and mask arrays
        :param csv_path: path of CSV file to write; its parent directory is created if needed
        :return: filename of the CSV file
        """
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        # These columns hold Python/numpy objects that do not round-trip through CSV.  errors='ignore' tolerates
        # their absence, e.g. in a DataFrame built without segmentation masks.
        unserializable = ['expanded_slice', 'slice', 'relative_slice', 'binary_mask']
        self._df.drop(unserializable, axis=1, errors='ignore').to_csv(csv_path, index=False)
        return csv_path.name


    def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame:
        """
        Write one PNG mask per ROI and return the mask DataFrame with a 'patch_mask_path' column
        :param where: directory in which to write masks
        :param pad_to: (optional) passed to get_patch_masks()
        :param prefix: prefix of each mask's filename
        :param expanded: use each ROI's expanded bounding box if True, else its tight bounding box
        :return: copy of the RoiSet DataFrame with columns 'patch_mask' and 'patch_mask_path' (filename relative
            to where)
        """
        patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy()

        def _export_patch_mask(roi):
            patch = InMemoryDataAccessor(roi.patch_mask)
            ext = 'png'
            fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
            write_accessor_data_to_file(where / fname, patch)
            return fname

        # Build from a list: DataFrame.apply(axis=1) returns a DataFrame (not a Series) for an empty RoiSet, and that
        # cannot be assigned to a single column
        patches_df['patch_mask_path'] = pd.Series(
            [_export_patch_mask(roi) for _, roi in patches_df.iterrows()],
            index=patches_df.index,
            dtype=object,
        )
        return patches_df


    def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.DataFrame:
        """
        Write one image file per ROI patch and return the patch DataFrame with a 'patch_path' column
        :param where: directory in which to write patches
        :param prefix: prefix of each patch's filename
        :param kwargs: passed through to get_patches(); make_3d, or force_tif if truthy, also selects TIF output, as
            do patches with more than three channels (otherwise PNG)
        :return: copy of the RoiSet DataFrame with columns 'patch' and 'patch_path' (filename relative to where)
        """
        make_3d = kwargs.get('make_3d', False)
        patches_df = self.get_patches(**kwargs).copy()

        def _export_patch(roi):
            patch = InMemoryDataAccessor(roi.patch)
            ext = 'tif' if make_3d or patch.chroma > 3 or kwargs.get('force_tif') else 'png'
            fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'

            # compare dtypes by value; identity (`is`) is not guaranteed, e.g. for non-native byte order
            if patch.dtype == np.dtype('uint16'):
                resampled = patch.apply(resample_to_8bit)
                write_accessor_data_to_file(where / fname, resampled)
            else:
                write_accessor_data_to_file(where / fname, patch)
            return fname

        # Build from a list: DataFrame.apply(axis=1) returns a DataFrame (not a Series) for an empty RoiSet, and that
        # cannot be assigned to a single column
        patches_df['patch_path'] = pd.Series(
            [_export_patch(roi) for _, roi in patches_df.iterrows()],
            index=patches_df.index,
            dtype=object,
        )
        return patches_df

    def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame:
        """
        Make one uint8 mask patch per ROI (255 inside the object, 0 elsewhere)
        :param pad_to: (optional) passed to process.pad()
        :param expanded: if True, the mask fills each ROI's expanded bounding box; else its tight bounding box
        :return: copy of the RoiSet DataFrame with a 'patch_mask' column of 4D (h, w, 1, 1) arrays
        """
        def _make_patch_mask(roi):
            # build a 2D mask in either case; the channel and z axes are added once, at the end
            if expanded:
                patch = np.zeros((roi.ebb_h, roi.ebb_w), dtype='uint8')
                patch[roi.relative_slice[0:2]] = roi.binary_mask * 255
            else:
                patch = (roi.binary_mask * 255).astype('uint8')
            if pad_to:
                patch = pad(patch, pad_to)
            return np.expand_dims(patch, (2, 3))

        dfe = self._df.copy()
        # Build from a list: DataFrame.apply(axis=1) returns a DataFrame (not a Series) for an empty RoiSet, and that
        # cannot be assigned to a single column
        dfe['patch_mask'] = pd.Series(
            [_make_patch_mask(roi) for _, roi in dfe.iterrows()],
            index=dfe.index,
            dtype=object,
        )
        return dfe

    def get_patch_masks_acc(self, **kwargs) -> PatchStack:
        """Return ROI mask patches as a PatchStack; kwargs are passed through to get_patch_masks()"""
        masks_df = self.get_patch_masks(**kwargs)
        return PatchStack(list(masks_df['patch_mask']))

    def get_patches(
            self,
            rescale_clip: float = 0.0,
            pad_to: int = None,
            make_3d: bool = False,
            focus_metric: str = None,
            rgb_overlay_channels: list = None,
            rgb_overlay_weights: list = (1.0, 1.0, 1.0),
            white_channel: int = None,
            expanded=False,
            **kwargs
    ) -> pd.DataFrame:
        """
        Make one image patch per ROI and return the patches in a copy of the RoiSet DataFrame
        :param rescale_clip: if not None, passed to process.rescale() for each patch
        :param pad_to: (optional) passed to process.pad(); only expanded patches are padded
        :param make_3d: keep all z-slices of expanded patches; otherwise expanded patches are reduced to one z-slice
        :param focus_metric: (optional) key of focus_metrics() used to pick, per channel, the z-slice of a 2D expanded
            patch; if None, the middle z-slice is used
        :param rgb_overlay_channels: (optional) three raw channel indices (or None) overlaid onto R, G, and B
        :param rgb_overlay_weights: weights of the three RGB overlay channels
        :param white_channel: (optional) raw channel used as grayscale background, or as the only channel
        :param expanded: use each ROI's expanded bounding box if True; else its tight box in its own z-slice
        :param kwargs: optional annotation settings (draw_bounding_box, bounding_box_channel, bounding_box_linewidth,
            draw_mask, mask_channel, draw_contour, contour_channel), and channels, a list of raw channels to include
        :return: copy of the RoiSet DataFrame with a 'patch' column of 4D (YXCZ) np.ndarrays
        """

        # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
        raw = self.acc_raw
        if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
            assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
            assert len(rgb_overlay_channels) == 3
            assert len(rgb_overlay_weights) == 3

            # compare with None: channel 0 is a valid grayscale background
            if white_channel is not None:
                assert white_channel < raw.chroma
                mono = raw.get_mono(white_channel).data_xyz
                stack = np.stack([mono, mono, mono], axis=2)
            else:
                stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)

            for ii, ci in enumerate(rgb_overlay_channels):
                if ci is None:
                    continue
                assert isinstance(ci, int)
                assert ci < raw.chroma
                stack[:, :, ii, :] = safe_add(
                    stack[:, :, ii, :],  # either black or grayscale channel
                    rgb_overlay_weights[ii],
                    raw.get_mono(ci).data_xyz
                )
        else:
            if white_channel is not None:  # interpret as just a single channel
                assert white_channel < raw.chroma
                annotate_rgb = False
                for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
                    ca = kwargs.get(k)
                    if ca is None:
                        continue
                    assert (ca < raw.chroma)
                    if ca != white_channel:
                        annotate_rgb = True
                        break
                if annotate_rgb:  # make RGB patches anyway to include annotation color
                    mono = raw.get_mono(white_channel).data_xyz
                    stack = np.stack([mono, mono, mono], axis=2)
                else:  # make monochrome patches
                    stack = raw.get_mono(white_channel).data
            elif kwargs.get('channels'):
                stack = raw.get_channels(kwargs['channels']).data
            else:
                stack = raw.data

        def _make_patch(roi):
            if expanded:
                patch3d = stack[roi.expanded_slice]
                subpatch = patch3d[roi.relative_slice]
            else:
                patch3d = stack[roi.slice]
                subpatch = patch3d

            ph, pw, pc, pz = patch3d.shape

            # make a 3d patch
            if make_3d or not expanded:
                patch = patch3d

            # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
            elif focus_metric is not None:
                foc = focus_metrics()[focus_metric]

                patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)

                for ci in range(0, pc):
                    me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
                    zif = np.argmax(me)
                    patch[:, :, ci, 0] = patch3d[:, :, ci, zif]

            # make a 2d patch from middle of z-stack
            else:
                zim = floor(pz / 2)
                patch = patch3d[:, :, :, [zim]]

            assert len(patch.shape) == 4

            mask = np.zeros(patch3d.shape[0:2], dtype=bool)
            if expanded:
                mask[roi.relative_slice[0:2]] = roi.binary_mask
            else:
                mask = roi.binary_mask

            if rescale_clip is not None:
                patch = rescale(patch, rescale_clip)

            if kwargs.get('draw_bounding_box') is True and expanded:
                bci = kwargs.get('bounding_box_channel', 0)
                assert bci < 3
                if bci > 0:
                    patch = make_rgb(patch)
                for zi in range(0, patch.shape[3]):
                    patch[:, :, bci, zi] = draw_box_on_patch(
                        patch[:, :, bci, zi],
                        ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)),
                        linewidth=kwargs.get('bounding_box_linewidth', 1)
                    )

            if kwargs.get('draw_mask'):
                mci = kwargs.get('mask_channel', 0)
                # NOTE(review): multiplying by the inverted mask zeroes pixels inside the object; confirm intent
                for zi in range(0, patch.shape[3]):
                    patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]

            if kwargs.get('draw_contour'):
                mci = kwargs.get('contour_channel', 0)
                contours = get_safe_contours(mask)  # the mask is the same for every z-slice, so compute it once
                for zi in range(0, patch.shape[3]):
                    patch[:, :, mci, zi] = draw_contours_on_patch(
                        patch[:, :, mci, zi],
                        contours
                    )

            if pad_to and expanded:
                patch = pad(patch, pad_to)
            return patch

        dfe = self._df.copy()
        # Build from a list: DataFrame.apply(axis=1) returns a DataFrame (not a Series) for an empty RoiSet, and that
        # cannot be assigned to a single column
        dfe['patch'] = pd.Series(
            [_make_patch(roi) for _, roi in dfe.iterrows()],
            index=dfe.index,
            dtype=object,
        )
        return dfe

    def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict:
        """
        Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
        :param where: path of directory in which to write all export products
        :param channel: color channel of products to export
        :param prefix: prefix of the name of each product's file or subfolder
        :param params: RoiSetExportParams object describing which products to export and with which parameters;
            products whose parameters are None, or whose flag is False, are skipped
        :return: nested dict of paths (as str) describing the location of export products; empty if there are no ROIs
        """
        record = {}
        if not self.count:
            return record

        for k, kp in params.dict().items():
            # None disables a parameterized product; False disables a flag-only product
            if kp is None or kp is False:
                continue
            subdir = where / k
            if k == 'patches_3d':
                df_exp = self.export_patches(
                    subdir, white_channel=channel, prefix=prefix, make_3d=True, **kp
                )
                record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
            elif k == 'annotated_patches_2d':
                df_exp = self.export_patches(
                    subdir, prefix=prefix, make_3d=False, white_channel=channel,
                    bounding_box_channel=1, bounding_box_linewidth=2, **kp,
                )
                record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
            elif k == 'patches_2d':
                df_exp = self.export_patches(
                    subdir, white_channel=channel, prefix=prefix, make_3d=False, **kp
                )
                # assign rather than join, so a repeated export overwrites the column instead of raising on overlap
                self._df = self._df.assign(
                    patch_path=df_exp.patch_path.apply(lambda x: str(Path('patches_2d') / x))
                )
                self._df['patch_id'] = [uuid4() for _ in range(len(self._df))]
                record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
            elif k == 'annotated_zstacks':
                record[k] = str(Path(k) / self.export_annotated_zstack(subdir, prefix=prefix, **kp))
            elif k == 'object_classes':
                col_prefix = 'classify_by_'
                cnames = [c[len(col_prefix):] for c in self._df.columns if c.startswith(col_prefix)]
                for n in cnames:
                    # one subfolder per classification, named like the other products after prefix
                    fp = subdir / n / (prefix + '.tif')
                    write_accessor_data_to_file(fp, self.get_object_class_map(n))
                    record[f'{k}_{n}'] = str(fp)
            elif k == 'derived_channels':
                record[k] = []
                for di, dacc in enumerate(self.accs_derived):
                    fp = subdir / f'dc{di:01d}.tif'
                    fp.parent.mkdir(exist_ok=True, parents=True)
                    dacc.export_pyxcz(fp)
                    record[k].append(str(fp))

        # export dataframe and patch masks
        record = {**record, **self.serialize(where, prefix=prefix)}

        return record

    def serialize(self, where: Path, prefix='') -> dict:
        """
        Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks
        :param where: path of directory in which to write files (a str is also accepted)
        :param prefix: (optional) prefix passed to export_patch_masks and used as the CSV filename stem
        :return: dict of str paths describing the export products: 'dataframe' is the CSV file path (prefixed with
            'dataframe/'), 'tight_patch_masks' is a list of patch mask paths (prefixed with 'tight_patch_masks/')
        """
        where = Path(where)
        record = {}
        df_exp = self.export_patch_masks(
            where / 'tight_patch_masks',
            prefix=prefix,
            pad_to=None,
            expanded=False
        )
        se_pa = df_exp.patch_mask_path.apply(
            lambda x: str(Path('tight_patch_masks') / x)
        ).rename('tight_patch_masks_path')
        # DataFrame.join raises on overlapping column names, so discard any column left by a previous call
        self._df = self._df.drop(columns=se_pa.name, errors='ignore').join(se_pa)
        df_fn = self.export_dataframe(where / 'dataframe' / (prefix + '.csv'))
        record['dataframe'] = str(Path('dataframe') / df_fn)
        record['tight_patch_masks'] = list(se_pa)
        return record

    def get_polygons(self, poly_threshold=0, dilation_radius=1) -> pd.Series:
        """
        Fit polygons to all object boundaries in the RoiSet
        :param poly_threshold: tolerance passed to approximate_polygon; a smaller number follows sharp features more
            closely
        :param dilation_radius: radius of the disk footprint used for binary dilation before fitting the polygon
        :return: Series, one entry per row of the RoiSet dataframe, of (variable x 2) np.ndarrays describing (x, y)
            polygon coordinates, offset by each ROI's (x0, y0)
        :raises PatchMaskShapeError: if a patch mask is not 2D, or no contour encloses all other contours
        """

        # pad with background so that contours of objects touching the patch edge are closed by find_contours
        pad_to = 1

        def _poly_from_mask(roi):
            mask = roi.binary_mask
            if len(mask.shape) != 2:
                raise PatchMaskShapeError('Patch mask needs to be two dimensions to fit a polygon')

            # label and fill holes; exactly one object is expected per patch mask
            labeled = label(mask)
            filled = [rp.image_filled for rp in regionprops(labeled)]
            assert (np.unique(labeled)[-1] == 1) and (len(filled) == 1), 'Cannot fit multiple polygons in a single patch mask'

            closed = binary_dilation(filled[0], footprint=disk(dilation_radius))
            padded = np.pad(closed, pad_to) * 1.0
            all_contours = find_contours(padded)

            # choose the outer contour: the first one that contains the points of every contour
            for candidate in all_contours:
                if all(points_in_poly(other, candidate).all() for other in all_contours):
                    contour = candidate
                    break
            else:
                raise PatchMaskShapeError('Could not identify an outer contour enclosing all others in patch mask')

            # find_contours yields (row, col); swap to (x, y) and remove the padding offset
            rel_polygon = approximate_polygon(contour[:, [1, 0]], poly_threshold) - [pad_to, pad_to]
            return rel_polygon + [roi.x0, roi.y0]

        return self._df.apply(_poly_from_mask, axis=1)


    @property
    def acc_obj_ids(self):
        """
        Object ID mask built from the current RoiSet dataframe; recomputed on every access (not cached)
        :return: result of make_object_ids_from_df for this RoiSet's dataframe and the raw accessor's shape_dict
        """
        dataframe = self._df
        shape_dict = self.acc_raw.shape_dict
        return make_object_ids_from_df(dataframe, shape_dict)

    # TODO: make this work with obj det dataset
    @staticmethod
    def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''):
        """
        Create an RoiSet object from saved files and an image accessor
        :param acc_raw: accessor to image that contains ROIs
        :param where: path to directory containing RoiSet serialization files, namely dataframe/<prefix>.csv and a
            subdirectory named tight_patch_masks (a str is also accepted)
        :param prefix: starting prefix of the CSV filename and of patch mask filenames
        :return: RoiSet built from acc_raw and an object ID mask reconstructed from the saved dataframe and masks
        :raises DeserializeRoiSet: if a patch mask cannot be read, is not single-channel and single-slice, or does
            not have an integer dtype
        """
        where = Path(where)
        df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']]

        def _read_binary_mask(r):
            ext = 'png'
            fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}'
            try:
                ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname)
                # explicit check rather than assert so validation survives python -O
                if ma_acc.chroma != 1 or ma_acc.nz != 1:
                    raise ValueError(f'Expected a single-channel, single-slice patch mask in {fname}')
                # scale integer mask data by its dtype's maximum value
                mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max
                return mask_data
            except Exception as e:
                raise DeserializeRoiSet(e) from e
        df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
        id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
        return RoiSet.from_object_ids(acc_raw, id_mask)



class Error(Exception):
    """Base class for the custom exceptions defined in this module."""


class DeserializeRoiSet(Error):
    """Raised when an RoiSet cannot be recreated from its serialized files."""


class DerivedChannelError(Error):
    """Error relating to derived channels."""


class MissingSegmentationError(Error):
    """Error indicating that a required segmentation is missing."""


class PatchMaskShapeError(Error):
    """Raised when a patch mask's shape is unsuitable for the requested operation."""