from collections import OrderedDict
import itertools
from math import sqrt, floor
from pathlib import Path
from typing import Dict, List, Union
from typing_extensions import Self
from uuid import uuid4

import glasbey
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from scipy.stats import moment
from skimage.filters import sobel

from skimage import draw
from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops, regionprops_table, shannon_entropy
from skimage.morphology import binary_dilation, disk

from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from .models import InstanceMaskSegmentationModel
from .process import get_safe_contours, pad, rescale, make_rgb
from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
from .accessors import generate_file_accessor, PatchStack
from .process import mask_largest_object


class PatchParams(BaseModel):
    # Pydantic model that bundles the options for generating image patches from ROIs.
    # NOTE(review): the code that consumes these fields is not visible in this part of the file; apart from
    # the two Field descriptions below, field meanings are implied by their names only -- confirm against
    # the patch-generation code before relying on them.
    make_3d: bool = False
    is_patch_mask: bool = False
    white_channel: Union[int, None] = None
    channels: Union[List[int], None] = None
    draw_bounding_box: bool = False
    draw_contour: bool = False
    draw_mask: bool = False
    rescale_clip: Union[float, None] = None
    # presumably one of the keys returned by focus_metrics() -- TODO confirm
    focus_metric: Union[str, None] = 'max_sobel'
    rgb_overlay_channels: Union[List[Union[int, None]], None] = None
    rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0]
    pad_to: Union[int, None] = 256
    expanded: bool = False
    force_tif: bool = False
    update_focus_zi: bool = Field(
        False,
        description='If generating 2d patches with a 3d focus metric, update the RoiSet patch focus zi accordingly'
    )
    mask_mip: bool = Field(
        False,
        description='If generating 2d patch masks, use the MIP of the 3d patch mask instead of existing zi focus'
    )

class AnnotatedPatchParams(PatchParams):
    # Extends PatchParams with two bounding-box drawing options.
    # NOTE(review): presumably the channel the box is drawn into and its line width in pixels -- confirm
    # against the annotation code, which is not visible here.
    bounding_box_channel: int = 1
    bounding_box_linewidth: int = 2

class RoiSetLabelsOverlayParams(BaseModel):
    # Field names and defaults mirror the keyword arguments of RoiSet.get_object_identities_overlay_map();
    # white_channel is the only required field.
    # NOTE(review): presumably unpacked into that method on export -- confirm against the export code.
    white_channel: int
    transparency: float = 0.5
    mip: bool = False
    rescale_clip: Union[float, None] = None

class AnnotatedZStackParams(BaseModel):
    # Options for annotated z-stack export.
    # NOTE(review): presumably forwarded as keyword arguments to draw_boxes_on_3d_image() via
    # RoiSet.export_annotated_zstack() -- confirm; that helper's signature is not visible here.
    draw_label: bool = False
    channel: Union[int, None] = None


class RoiFilterRange(BaseModel):
    # Pair of bounds for one ROI property. filter_df() applies them as exclusive limits
    # (min < value < max) and asserts that min is non-negative.
    min: float
    max: float


class RoiFilter(BaseModel):
    # Optional ranges keyed by bounding-box column name. filter_df() accepts exactly these three keys;
    # a key that is unset or None adds no constraint.
    area: Union[RoiFilterRange, None] = None
    diag: Union[RoiFilterRange, None] = None
    min_hw: Union[RoiFilterRange, None] = None


class RoiSetMetaParams(BaseModel):
    # Parameters that influence how ROIs are defined when a RoiSet is built.
    # filters: passed to make_df_from_object_ids() and applied by filter_df()
    filters: Union[RoiFilter, None] = None
    # expand_box_by: unpacked by df_insert_slices() as (xy expansion, z expansion)
    expand_box_by: List[int] = [0, 0]
    # deproject_channel: channel used to assign z-positions when objects are 2d but the raw stack is 3d
    deproject_channel: Union[int, None] = None
    # deproject_intensity_threshold: fraction of the channel's dtype maximum below which MIP intensity is zeroed
    deproject_intensity_threshold: float = 0.0


class RoiSetExportParams(BaseModel):
    # Selects which products are exported for a RoiSet; None / False entries are the defaults.
    # NOTE(review): the export routine that reads this model is not visible in this part of the file, so
    # how each entry is acted on should be confirmed there.
    # patches: mapping of a name to the PatchParams used for that set of patches
    patches: Union[Dict[str, PatchParams], None] = None
    annotated_zstacks: Union[AnnotatedZStackParams, None] = None
    object_classes: bool = False
    labels_overlay: Union[RoiSetLabelsOverlayParams, None] = None
    derived_channels: bool = False
    write_patches_to_subdirectory: bool = Field(
        False,
        description='Write all patches to a subdirectory with prefix as name'
    )


def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor:
    """
    Convert binary segmentation mask into either a 2D or 3D object identities map
    :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions
    :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False
    :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
    :return: object identities map (uint16), where 0 is background
    """
    if allow_3d and connect_3d:
        nda_la = label(
            acc_seg_mask.data_yxz,
            connectivity=3,
        ).astype('uint16')
        # re-insert the singleton channel axis expected by the accessor
        return InMemoryDataAccessor(np.expand_dims(nda_la, 2))
    elif allow_3d and not connect_3d:
        # label each z-slice independently, offsetting IDs so that they stay unique across the stack
        nla = 0
        la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16')
        for zi in range(0, acc_seg_mask.nz):
            la_2d = label(
                acc_seg_mask.data_yxz[:, :, zi],
                connectivity=2,
            ).astype('uint16')
            la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla
            # keep the running maximum: a slice without any objects must not reset the offset to zero,
            # otherwise labels in later slices would collide with labels already assigned
            nla = max(nla, int(la_2d.max()))
            la_3d[:, :, 0, zi] = la_2d
        return InMemoryDataAccessor(la_3d)
    else:
        return InMemoryDataAccessor(
            label(
                acc_seg_mask.get_mip().data_yx,
                connectivity=1,
            ).astype('uint16')
        )


def focus_metrics():
    """
    Build the lookup of available focus metrics
    :return: dict that maps each metric name to a callable taking an array and returning a scalar score
    """
    def _max_intensity(x):
        return np.max(x)

    def _stdev(x):
        return np.std(x)

    def _max_sobel(x):
        return np.max(sobel(x))

    def _rms_sobel(x):
        edges = sobel(x)
        return sqrt(np.mean(edges ** 2))

    def _entropy(x):
        return shannon_entropy(x)

    def _second_moment(x):
        return moment(x.flatten(), moment=2)

    metrics = {}
    metrics['max_intensity'] = _max_intensity
    metrics['stdev'] = _stdev
    metrics['max_sobel'] = _max_sobel
    metrics['rms_sobel'] = _rms_sobel
    metrics['entropy'] = _entropy
    metrics['moment'] = _second_moment
    return metrics


def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
    """
    Select the rows of a RoiSet dataframe whose bounding-box properties fall inside the given ranges
    :param df: dataframe with a 'bounding_box' column level and an index named 'label'
    :param filters: (optional) ranges on 'area', 'diag', and/or 'min_hw'; bounds are exclusive
    :return: subset of df that satisfies every specified range
    """
    clauses = ['label > 0']  # always true
    if filters is not None:  # parse filters
        for key, bounds in filters.dict(exclude_unset=True).items():
            assert key in ('area', 'diag', 'min_hw')
            if bounds is None:
                continue
            lower, upper = bounds['min'], bounds['max']
            assert lower >= 0
            clauses.append(f'{key} > {lower}')
            clauses.append(f'{key} < {upper}')
    selected = df.bounding_box.query(' & '.join(clauses)).index
    return df.loc[selected, :]


def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
    """
    If passed a single DataFrame, return the subset whose bounding boxes overlap in XY and share the same zi.  If
    passed two DataFrames, return the subset where a ROI in the first overlaps a ROI in the second.  May return
    duplicates entries where a ROI overlaps with multiple neighbors.
    :param df1: DataFrame with potentially overlapping bounding boxes
    :param df2: (optional) second DataFrame
    :return DataFrame of bounding box columns describing subset of overlapping ROIs, with additional columns:
        overlaps_with: index of ROI that overlaps
        bbox_intersec: pixel area of intersecting region
    """

    def _compare(r0, r1):
        olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)
        oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0)
        olz = (r0.zi == r1.zi)
        return olx and oly and olz

    def _intersec(r0, r1):
        # area of the rectangle common to both boxes; only called once _compare has established an overlap, so
        # both factors are positive
        iw = min(r0.x1, r1.x1) - max(r0.x0, r1.x0)
        ih = min(r0.y1, r1.y1) - max(r0.y0, r1.y0)
        return iw * ih

    first = []
    second = []
    intersec = []

    bb1 = df1.bounding_box
    if df2 is not None:
        bb2 = df2.bounding_box
        for i0, i1 in itertools.product(df1.index, df2.index):
            r0 = bb1.loc[i0]
            r1 = bb2.loc[i1]
            if _compare(r0, r1):
                first.append(i0)
                second.append(i1)
                intersec.append(_intersec(r0, r1))
    else:
        for i0, i1 in itertools.combinations(df1.index, 2):
            r0 = bb1.loc[i0]
            r1 = bb1.loc[i1]
            if _compare(r0, r1):
                # report the overlap symmetrically, once from the point of view of each ROI
                isc = _intersec(r0, r1)
                first.extend([i0, i1])
                second.extend([i1, i0])
                intersec.extend([isc, isc])

    # explicit copy so that new columns are never written through to df1
    sdf = bb1.loc[first].copy()
    sdf['overlaps_with'] = second
    sdf['bbox_intersec'] = intersec
    return sdf


def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
    """
    Starting from the ROI pairs whose bounding boxes overlap (see filter_df_overlap_bbox), quantify how much their
    segmentation masks overlap.  With one DataFrame, pairs are drawn from within it; with two, each pair is a ROI
    of the first and a ROI of the second.  A ROI appears once per overlapping neighbor.
    :param df1: DataFrame with potentially overlapping bounding boxes
    :param df2: (optional) second DataFrame
    :return DataFrame returned by filter_df_overlap_bbox, with additional columns:
        seg_overlaps: True if the two masks share at least one pixel
        seg_intersec: pixel area of intersecting region
        seg_iou: intersection over union
    """

    candidates = filter_df_overlap_bbox(df1, df2)
    partner_df = df1 if df2 is None else df2

    def _sum_masks(row):
        # paint both masks onto a canvas spanning both boxes; a pixel covered by both ends up equal to 2
        roi_a = df1.loc[row.name]
        roi_b = partner_df.loc[row.overlaps_with]
        box_a = roi_a.bounding_box
        box_b = roi_b.bounding_box

        left = min(box_a.x0, box_b.x0, box_a.x1, box_b.x1)
        width = max(box_a.x0, box_b.x0, box_a.x1, box_b.x1) - left
        top = min(box_a.y0, box_b.y0, box_a.y1, box_b.y1)
        height = max(box_a.y0, box_b.y0, box_a.y1, box_b.y1) - top

        canvas = np.zeros((height, width), dtype='uint8')
        region_a = np.s_[(box_a.y0 - top): (box_a.y1 - top), (box_a.x0 - left): (box_a.x1 - left)]
        region_b = np.s_[(box_b.y0 - top): (box_b.y1 - top), (box_b.x0 - left): (box_b.x1 - left)]
        canvas[region_a] = roi_a.masks.binary_mask
        canvas[region_b] = canvas[region_b] + roi_b.masks.binary_mask
        return canvas

    summed = candidates.apply(_sum_masks, axis=1)
    candidates['seg_overlaps'] = summed.apply(lambda m: np.any(m > 1))
    candidates['seg_intersec'] = summed.apply(lambda m: (m == 2).sum())
    candidates['seg_iou'] = summed.apply(lambda m: (m == 2).sum() / (m > 0).sum())
    return candidates


def is_df_3d(df: pd.DataFrame) -> bool:
    """
    Report whether a RoiSet dataframe describes 3d bounding boxes
    :param df: dataframe with a 'bounding_box' column level
    :return: True if the bounding box columns include both 'z0' and 'z1'
    """
    bbox_columns = df.bounding_box.columns
    return all(key in bbox_columns for key in ('z0', 'z1'))


def make_df_from_object_ids(
        acc_raw,
        acc_obj_ids,
        expand_box_by,
        deproject_channel=None,
        filters=None,
        deproject_intensity_threshold=0.0
) -> pd.DataFrame:
    """
    Build dataframe that associate object IDs with summary stats;
    :param acc_raw: accessor to raw image data
    :param acc_obj_ids: accessor to map of object IDs
    :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
    :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel
    :param filters: (optional) RoiFilter; objects outside of the specified ranges are dropped
    :param deproject_intensity_threshold: when deprojecting, round MIP deprojection_channel to zero if below this
        threshold (as fraction of full range, 0.0 to 1.0)
    :return: pd.DataFrame indexed by object label, with column levels for bounding boxes, slices, and binary masks
    :raises NoDeprojectChannelSpecifiedError: if deprojection is needed on a multichannel stack but no valid
        deprojection channel is given
    """
    # build dataframe of objects, assign z index to each object

    if acc_obj_ids.nz == 1 and acc_raw.nz > 1:  # apply deprojection

        if deproject_channel is None or deproject_channel >= acc_raw.chroma or deproject_channel < 0:
            if acc_raw.chroma == 1:
                deproject_channel = 0
            else:
                raise NoDeprojectChannelSpecifiedError(
                    f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.'
                )

        mono = acc_raw.get_mono(deproject_channel)
        intensity_weight = mono.get_mip().data_yx.astype('uint16')
        intensity_weight[intensity_weight < (deproject_intensity_threshold * mono.dtype_max)] = 0
        argmax = mono.get_z_argmax().data_yx.astype('uint16')

        # the weighted product is formed in uint32: as uint16 it wraps around as soon as
        # argmax * intensity exceeds 65535, which corrupts the intensity-weighted mean z computed below
        zi_map = np.stack([
            intensity_weight.astype('uint32'),
            argmax.astype('uint32') * intensity_weight,
        ], axis=-1)

        assert len(zi_map.shape) == 3
        df = pd.DataFrame(regionprops_table(
            acc_obj_ids.data_yx,
            intensity_image=zi_map,
            properties=('label', 'area', 'intensity_mean', 'bbox')
        )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'})

        # intensity-weighted mean of argmax over each object; objects with zero total weight fall back to zi = 0
        df['zi'] = (df['intensity_mean-1'] / df['intensity_mean-0']).fillna(0).round().astype('int16')
        df = df.drop(['intensity_mean-0', 'intensity_mean-1'], axis=1)

        def _make_binary_mask(r):
            # applied after the index is set to 'label', so r.name is the object's label
            acc = InMemoryDataAccessor(acc_obj_ids.data == r.name)
            cropped = acc.get_mono(0, mip=True).crop_hw(
                (int(r.y0), int(r.x0), int(r.y1 - r.y0), int(r.x1 - r.x0))
            ).data_yx
            return cropped

    elif acc_obj_ids.nz == 1 and acc_raw.nz == 1:  # purely 2d, no z information in dataframe
        df = pd.DataFrame(regionprops_table(
            acc_obj_ids.data_yx,
            properties=('label', 'area', 'bbox')
        )).rename(columns={
            'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'
        })

        def _make_binary_mask(r):
            # applied after the index is set to 'label', so r.name is the object's label
            acc = InMemoryDataAccessor(acc_obj_ids.data == r.name)
            cropped = acc.get_mono(0).crop_hw(
                (int(r.y0), int(r.x0), int(r.y1 - r.y0), int(r.x1 - r.x0))
            ).data_yx
            return cropped

    else:  # purely 3d: objects' median z-coordinates come from arg of max count in object identities map
        df = pd.DataFrame(regionprops_table(
            acc_obj_ids.data_yxz,
            properties=('label', 'area', 'bbox')
        )).rename(columns={
            'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
        })

        def _get_zi_from_label(r):
            r = r.convert_dtypes()
            # this runs before the index is set to 'label', so r.name is only the row position;
            # the object's identity must be read from the 'label' column
            la = r['label']
            crop = acc_obj_ids.crop_hwd((r.y0, r.x0, r.z0, r.y1 - r.y0, r.x1 - r.x0, r.z1 - r.z0))
            rel_argzmax = crop.apply(lambda x: x == la).get_focus_vector().argmax()
            return rel_argzmax + r.z0

        df['zi'] = df.apply(_get_zi_from_label, axis=1, result_type='reduce')
        df['nz'] = df['z1'] - df['z0']

        def _make_binary_mask(r):
            # applied after the index is set to 'label', so r.name is the object's label
            r = r.convert_dtypes()
            la = r.name
            crop = acc_obj_ids.crop_hwd(
                (int(r.y0), int(r.x0), int(r.z0), int(r.y1 - r.y0), int(r.x1 - r.x0), int(r.z1 - r.z0))
            )
            return crop.apply(lambda x: x == la).data_yxz

    df = df.set_index('label')
    insert_level(df, 'bounding_box')
    df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
    df_fil = filter_df(df, filters)
    df_fil['masks', 'binary_mask'] = df_fil.bounding_box.apply(
        _make_binary_mask,
        axis=1,
        result_type='reduce',
    )
    return df_fil


def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
    """
    Add derived bounding-box geometry and the corresponding slice objects to a RoiSet dataframe (in place)
    :param df: dataframe with a 'bounding_box' column level containing y0, y1, x0, x1 and either z0/z1 or zi
    :param sd: shape dictionary of the raw stack; keys 'Y', 'X', and 'Z' are read
    :param expand_box_by: pair (xy, z) of pixels by which to expand bounding boxes, limited to the stack's extent
    :return: the same dataframe, with added 'expanded_bounding_box', 'relative_bounding_box', and 'slices' levels
    """
    h = sd['Y']
    w = sd['X']
    nz = sd['Z']

    bb = 'bounding_box'
    df[bb, 'h'] = df[bb, 'y1'] - df[bb, 'y0']
    df[bb, 'w'] = df[bb, 'x1'] - df[bb, 'x0']
    df[bb, 'diag'] = (df[bb, 'w'] ** 2 + df[bb, 'h'] ** 2).apply(sqrt)
    df[bb, 'min_hw'] = df[[(bb, 'w'), (bb, 'h')]].min(axis=1)

    ebxy, ebz = expand_box_by
    ebb = 'expanded_bounding_box'
    df[ebb, 'ebb_y0'] = (df[bb, 'y0'] - ebxy).clip(lower=0)
    df[ebb, 'ebb_y1'] = (df[bb, 'y1'] + ebxy).clip(upper=h)
    df[ebb, 'ebb_x0'] = (df[bb, 'x0'] - ebxy).clip(lower=0)
    df[ebb, 'ebb_x1'] = (df[bb, 'x1'] + ebxy).clip(upper=w)

    # handle based on whether bounding box coordinates are 2d or 3d
    if is_df_3d(df):
        df[ebb, 'ebb_z0'] = (df[bb, 'z0'] - ebz).clip(lower=0)
        # limit the upper z bound to the stack depth, as for every other upper bound; clamping from below
        # instead would extend every expanded box to the end of the stack
        df[ebb, 'ebb_z1'] = (df[bb, 'z1'] + ebz).clip(upper=nz)
    else:
        if 'zi' not in df.bounding_box.columns:
            df[bb, 'zi'] = 0
        df[ebb, 'ebb_z0'] = (df[bb, 'zi'] - ebz).clip(lower=0)
        df[ebb, 'ebb_z1'] = (df[bb, 'zi'] + ebz).clip(upper=nz)

    df[ebb, 'ebb_h'] = df[ebb, 'ebb_y1'] - df[ebb, 'ebb_y0']
    df[ebb, 'ebb_w'] = df[ebb, 'ebb_x1'] - df[ebb, 'ebb_x0']
    # NOTE(review): ebb_z1 is treated as an inclusive index here and in expanded_slice below, although it may
    # equal nz; confirm whether the limit above should be nz - 1
    df[ebb, 'ebb_nz'] = df[ebb, 'ebb_z1'] - df[ebb, 'ebb_z0'] + 1

    # compute relative bounding boxes
    rbb = 'relative_bounding_box'
    df[rbb, 'rel_y0'] = df[bb, 'y0'] - df[bb, 'y0']
    df[rbb, 'rel_y1'] = df[bb, 'y1'] - df[bb, 'y0']
    df[rbb, 'rel_x0'] = df[bb, 'x0'] - df[bb, 'x0']
    df[rbb, 'rel_x1'] = df[bb, 'x1'] - df[bb, 'x0']

    assert np.all(df[rbb, 'rel_x1'] <= (df[ebb, 'ebb_x1'] - df[ebb, 'ebb_x0']))
    assert np.all(df[rbb, 'rel_y1'] <= (df[ebb, 'ebb_y1'] - df[ebb, 'ebb_y0']))

    # slices index arrays in YXCZ order; the channel axis is always taken in full
    if is_df_3d(df):
        df['slices', 'slice'] = df['bounding_box'].apply(
            lambda r:
            np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.z0): int(r.z1)],
            axis=1,
            result_type='reduce',
        )
    else:
        df['slices', 'slice'] = df['bounding_box'].apply(
            lambda r:
            np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)],
            axis=1,
            result_type='reduce',
        )
    df['slices', 'expanded_slice'] = df['expanded_bounding_box'].apply(
        lambda r:
        np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1],
        axis=1,
        result_type='reduce',
    )
    df['slices', 'relative_slice'] = df['relative_bounding_box'].apply(
        lambda r:
        np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :],
        axis=1,
        result_type='reduce',
    )
    return df


def safe_add(a, g, b):
    """
    Compute a + g * b for two integer arrays without wrapping around, saturating at the limits of their dtype
    :param a: first array, of an integer dtype
    :param g: non-negative weight applied to b
    :param b: second array, same dtype and shape as a
    :return: array of same dtype and shape as a
    """
    assert a.dtype == b.dtype
    assert a.shape == b.shape
    assert g >= 0.0

    ceiling = np.iinfo(a.dtype).max
    # widen before adding so that the sum cannot overflow the input dtype
    total = a.astype('uint32') + g * b.astype('uint32')
    saturated = np.clip(total, 0, ceiling)
    return saturated.astype(a.dtype)


def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
    """
    Reconstruct an object identities map from the bounding boxes and binary masks in a RoiSet dataframe
    :param df: dataframe indexed by object label, with 'bounding_box' and 'masks' column levels
    :param sd: shape dictionary of the raw stack; keys 'Y', 'X', and 'Z' are read
    :return: accessor to a single-channel uint16 map in which each object's label is added at its mask position
    :raises MissingSegmentationError: if the dataframe has no binary masks
    """
    canvas = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16')

    if 'binary_mask' not in df.masks.columns:
        raise MissingSegmentationError('RoiSet dataframe does not contain segmentation')

    # choose how a bounding box maps onto the canvas, and which axes the stored mask is missing
    if is_df_3d(df):  # use 3d coordinates
        missing_axes = 2

        def _region(bb):
            return np.s_[bb.y0:bb.y1, bb.x0:bb.x1, :, bb.z0:bb.z1]
    elif 'zi' in df.bounding_box.columns:
        missing_axes = (2, 3)

        def _region(bb):
            return np.s_[bb.y0:bb.y1, bb.x0:bb.x1, :, bb.zi: (bb.zi + 1)]
    else:
        missing_axes = (2, 3)

        def _region(bb):
            return np.s_[bb.y0:bb.y1, bb.x0:bb.x1, :]

    def _paint(roi):
        where = _region(roi.bounding_box)
        stencil = np.expand_dims(roi.masks.binary_mask, missing_axes)
        canvas[where] = canvas[where] + roi.name * stencil

    df.apply(_paint, axis=1)
    return InMemoryDataAccessor(canvas)


def insert_level(df: pd.DataFrame, name: str) -> None:
    """
    Put all existing columns of a dataframe under a new top column level (in place)
    :param df: dataframe to modify
    :param name: name of the top-level key under which all existing columns are grouped
    """
    existing = list(df.columns.values)
    levels = [[name], existing]
    df.columns = pd.MultiIndex.from_product(levels)

def read_roiset_df(csv_path: Path) -> pd.DataFrame:
    """
    Read a RoiSet dataframe from CSV
    :param csv_path: path to a CSV file with two header rows and the index in its first column
    :return: dataframe with two-level columns
    """
    header_rows = [0, 1]
    roiset_df = pd.read_csv(csv_path, index_col=0, header=header_rows)
    return roiset_df


class RoiSet(object):

    def __init__(
            self,
            acc_raw: GenericImageDataAccessor,
            df: pd.DataFrame,
            params: RoiSetMetaParams = RoiSetMetaParams(),
    ):
        """
        Hold a collection of regions of interest (ROIs) that are located in the YXCZ space of a raw image stack.
        The object keeps the ROIs' state in a dataframe; export methods turn that state into patches, maps, and
        other products.
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param df: dataframe containing at minimum bounding box and segmentation mask information
        :param params: optional arguments that influence the definition and representation of ROIs
        """
        # ROI table and its size
        self._df = df
        self.count = len(df)

        # image data that the ROIs refer to
        self.acc_raw = acc_raw
        self.accs_derived = []

        self.params = params

    @classmethod
    def from_object_ids(
            cls,
            acc_raw: GenericImageDataAccessor,
            acc_obj_ids: GenericImageDataAccessor,
            params: RoiSetMetaParams = RoiSetMetaParams(),
    ):
        """
        Build a RoiSet from a map of object identities
        :param acc_raw: accessor to a generally multichannel z-stack
        :param acc_obj_ids: accessor to a single-channel object identities map, where each pixel's intensity
            labels its membership in a connected object
        :param params: optional arguments that influence the definition and representation of ROIs
        :return: RoiSet object
        """
        assert acc_obj_ids.chroma == 1

        # options from params that govern how the ROI table is derived from the map
        table_options = {
            'expand_box_by': params.expand_box_by,
            'deproject_channel': params.deproject_channel,
            'deproject_intensity_threshold': params.deproject_intensity_threshold,
            'filters': params.filters,
        }
        roi_table = make_df_from_object_ids(acc_raw, acc_obj_ids, **table_options)
        return cls(acc_raw, roi_table, params)

    @classmethod
    def from_bounding_boxes(
        cls,
        acc_raw: GenericImageDataAccessor,
        bbox_yxhw: List[Dict],
        bbox_zi: Union[List[int], int] = None,
        params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create an RoiSet from bounding boxes
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param bbox_yxhw: list of bounding boxes, each a dict with keys Y, X (corner coordinates), H (height),
            and W (width); keys are not case-sensitive
        :param bbox_zi: (optional) z-position of all bounding boxes, or list with z-position of each; if not
            specified, the same z-position is assigned to every box from the focus vector of the deprojection channel
        :param params: optional arguments that influence the definition and representation of ROIs
        :return: RoiSet object
        :raises BoundingBoxError: if the bounding boxes do not have exactly the keys Y, X, H, and W
        :raises NoDeprojectChannelSpecifiedError: if z-positions are needed on a multichannel stack but no valid
            deprojection channel is given
        """
        bbox_df = pd.DataFrame(bbox_yxhw)
        if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']:
            raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}')

        # the check above accepts keys of either case, but all lookups below use lower case
        bbox_df.columns = bbox_df.columns.str.lower()

        # deproject if zi is not specified
        if bbox_zi is None:
            dch = params.deproject_channel
            if dch is None or dch >= acc_raw.chroma or dch < 0:
                if acc_raw.chroma == 1:
                    dch = 0
                else:
                    raise NoDeprojectChannelSpecifiedError(
                        f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.'
                    )
            bbox_df['zi'] = acc_raw.get_mono(dch).get_focus_vector().argmax()
        else:
            bbox_df['zi'] = bbox_zi

        bbox_df['y0'] = bbox_df['y']
        bbox_df['x0'] = bbox_df['x']
        bbox_df['y1'] = bbox_df['y0'] + bbox_df['h']
        bbox_df['x1'] = bbox_df['x0'] + bbox_df['w']

        # labels follow the order of the input list, starting from 0
        bbox_df['label'] = bbox_df.index

        bbox_df.set_index('label', inplace=True)
        bbox_df = bbox_df.drop(['x', 'y', 'w', 'h'], axis=1)
        insert_level(bbox_df, 'bounding_box')
        df = df_insert_slices(
            bbox_df,
            acc_raw.shape_dict,
            params.expand_box_by,
        )

        def _make_binary_mask(r):
            # without a segmentation, the whole bounding box counts as the object
            return np.ones((int(r.h), int(r.w)), dtype=bool)

        df['masks', 'binary_mask'] = df.bounding_box.apply(
            _make_binary_mask,
            axis=1,
            result_type='reduce',
        )
        return cls(acc_raw, df, params)


    @classmethod
    def from_binary_mask(
            cls,
            acc_raw: GenericImageDataAccessor,
            acc_seg: GenericImageDataAccessor,
            allow_3d=False,
            connect_3d=True,
            params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Build a RoiSet from a binary segmentation mask (either 2D or 3D) by first labeling its connected objects
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param acc_seg: accessor of a binary segmentation mask (mono) of either two or three dimensions
        :param allow_3d: use a 3D map if True; use a 2D map of the mask's maximum intensity project if False
        :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
        :param params: optional arguments that influence the definition and representation of ROIs
        :return: RoiSet object
        """
        acc_obj_ids = get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d)
        return cls.from_object_ids(acc_raw, acc_obj_ids, params)

    @classmethod
    def from_polygons_2d(
            cls,
            acc_raw,
            polygons: List[np.ndarray],
            params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create a RoiSet where objects are defined from a list of polygon coordinates
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param polygons: list of (variable x 2) np.ndarrays describing (x, y) polygon coordinates; parts of a
            polygon that lie outside of the image are ignored
        :param params: optional arguments that influence the definition and representation of ROIs
        :return: RoiSet object
        """
        mask = np.zeros(acc_raw.get_mono(0, mip=True).shape, dtype=bool)
        for p in polygons:
            # rows are y and columns are x; passing the image extent keeps the filled coordinates inside of
            # the mask, where otherwise they could raise an IndexError or wrap around via negative indices
            sl = draw.polygon(p[:, 1], p[:, 0], shape=mask.shape[:2])
            mask[sl] = True
        return cls.from_binary_mask(
            acc_raw,
            InMemoryDataAccessor(mask),
            allow_3d=False,
            connect_3d=False,
            params=params,
        )

    @property
    def info(self):
        """
        Summarize this RoiSet
        :return: dict with the raw stack's shape dictionary, number of ROIs, classification columns, and the
            dataframe's memory usage
        """
        summary = {}
        summary['raw_shape_dict'] = self.acc_raw.shape_dict
        summary['count'] = self.count
        summary['classify_by'] = self.classification_columns
        usage = self.df().memory_usage(deep=True).sum()
        summary['df_memory_usage'] = int(usage)
        return summary

    def df(self) -> pd.DataFrame:
        """
        Get the ROI table
        :return: copy of the internal dataframe, so that callers cannot modify this RoiSet's state through it
        """
        snapshot = self._df.copy()
        return snapshot

    def list_bounding_boxes(self):
        """
        List the bounding boxes of all ROIs
        :return: list with one dict per ROI, containing its index and all of its bounding box columns
        """
        # df is a method, not a property: it has to be called to obtain the dataframe
        return self.df().bounding_box.reset_index().to_dict(orient='records')

    def get_slices(self) -> pd.Series:
        """
        Get the slice object of each ROI
        :return: series of the ('slices', 'slice') column, taken from a copy of the ROI table
        """
        roi_table = self.df()
        return roi_table[('slices', 'slice')]

    def add_df_col(self, name, se: pd.Series) -> None:
        """
        Add or overwrite a column of the internal ROI table
        :param name: key of the column to set
        :param se: series of values to assign
        """
        roi_table = self._df
        roi_table[name] = se

    def get_patches_acc(self, channels: list = None, **kwargs) -> Union[PatchStack, None]:  # padded, un-annotated 2d patches
        """
        Collect the patches of all ROIs into a single patch stack
        :param channels: (optional) list of channels; a list with a single entry is passed on as white_channel
        :param kwargs: variable-length keyword arguments passed on to get_patches()
        :return: PatchStack, or None if this RoiSet is empty
        """
        if self.count == 0:
            return None

        wants_single_channel = bool(channels) and len(channels) == 1
        if wants_single_channel:
            patches = self.get_patches(white_channel=channels[0], **kwargs)
        else:
            patches = self.get_patches(channels=channels, **kwargs)
        return PatchStack(list(patches))

    def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> str:
        """
        Write a z-stack with ROI boxes drawn on it to a TIF file
        :param where: directory in which to write the file
        :param prefix: filename without extension
        :param kwargs: variable-length keyword arguments passed on to draw_boxes_on_3d_image()
        :return: name of the written file
        """
        filename = prefix + '.tif'
        boxed = draw_boxes_on_3d_image(self, **kwargs)
        write_accessor_data_to_file(where / filename, InMemoryDataAccessor(boxed))
        return filename

    def get_zmask(self, mask_type='boxes') -> np.ndarray:
        """
        Return a boolean mask with the same shape as the raw data

        :param mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours'
        :return: boolean array of same shape as self.acc_raw
        """

        assert mask_type in ('contours', 'boxes')
        zi_st = np.zeros(self.acc_raw.shape, dtype='bool')
        lamap = self.acc_obj_ids.data

        # make an object map where label is replaced by focus position in stack and background is -1
        lut = np.zeros(lamap.max() + 1) - 1
        df = self.df()
        lut[df.index] = df.bounding_box.zi

        if mask_type == 'contours':
            # round-trip through +1 / -1 so that the float lookup table becomes integer z-indices
            zi_map = (lut[lamap] + 1.0).astype('int')
            idxs = np.array(zi_map) - 1
            # NOTE(review): this indexing presumably expects a 2d object map -- confirm against acc_obj_ids,
            # which is not visible in this part of the file
            np.put_along_axis(
                zi_st,
                np.expand_dims(idxs, (2, 3)),
                1,
                axis=3
            )

            # background pixels carry index -1, which addresses the final z-position; reset those to False
            zi_st[:, :, :, -1][lamap == 0] = 0

        elif mask_type == 'boxes':
            def _set_box(sl):
                zi_st[sl] = True
            # each entry of the ('slices', 'slice') column indexes one object's bounding box in the stack
            self._df.slices.slice.apply(_set_box)
        return zi_st

    @property
    def is_3d(self) -> bool:
        """
        :return: True if the ROIs' bounding boxes have z0 and z1 coordinates
        """
        roi_table = self.df()
        return is_df_3d(roi_table)

    def classify_by(
            self, name: str, channels: list[int],
            object_classification_model: InstanceMaskSegmentationModel,
    ):
        """
        Add a classification column to the RoiSet data table, in which each ROI is assigned an integer class by
        running patches of the specified channels through an instance segmentation classifier.

        :param name: name of column to insert
        :param channels: list of nc raw input channels to send to classifier
        :param object_classification_model: InstanceSegmentation model object
        :return: True if the RoiSet is empty, otherwise None
        """
        if self.count == 0:
            self._df['classifications', name] = None
            return True

        # unpadded patches and their masks, so that the model sees only one object per frame
        patch_stack = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
        mask_stack = self.get_patch_masks_acc(expanded=False, pad_to=None)
        class_patches = object_classification_model.label_patch_stack(patch_stack, mask_stack)

        classes = pd.Series(dtype='Int64', index=self._df.index)
        for position, roi_label in enumerate(self._df.index):
            largest = mask_largest_object(class_patches.iat(position).data_yxz)
            # np.unique sorts its output, so the last entry is the highest class value in the patch
            classes[roi_label] = np.unique(largest)[-1]
        self.set_classification(name, classes)

    def get_instance_classification(self, roiset_from: Self, iou_min: float = 0.5) -> pd.DataFrame:
        """
        Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity
        :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet
        :param iou_min: threshold IOU; a label is transferred only if the IOU is greater than this value
        :return DataFrame of source RoiSet, including overlaps with this RoiSet and IOU metric
        :raises ShapeMismatchError: if the raw stacks of the two RoiSets differ in shape
        :raises MissingInstanceLabelsError: if the source RoiSet has no classification columns
        """
        if self.acc_raw.shape != roiset_from.acc_raw.shape:
            raise ShapeMismatchError(
                f'Expecting two RoiSets of same shape: {self.acc_raw.shape} != {roiset_from.acc_raw.shape}')

        columns = roiset_from.classification_columns

        if len(columns) == 0:
            raise MissingInstanceLabelsError('Expecting at least on instance classification channel but none found')

        # rows are indexed by ROIs of the source RoiSet; 'overlaps_with' refers to ROIs of this RoiSet
        df_overlaps = filter_df_overlap_seg(
            roiset_from.df(),
            self.df()
        )
        df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min

        # pair each source classification with the ROI of this RoiSet that it overlaps sufficiently, then
        # re-index by the ROI of this RoiSet
        df_merge = pd.merge(
            roiset_from.df().classifications[columns],
            df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']],
            left_index=True,
            right_index=True,
            how='inner',
        ).set_index('overlaps_with')
        for col in columns:
            self.set_classification(col, df_merge[col])

        return df_overlaps

    def get_object_class_map(self, class_name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor:
        """
        For a given classification result, return a map where object IDs are replaced by each object's class
        :param class_name: name of the classification result, same as passed to RoiSet.classify_by()
        :param filter_by: only include ROIs if the intersection of all specified classifications is True
        :return: accessor of object class map
        """
        assert class_name in self._df.classifications.columns
        obj_ids = self.acc_obj_ids
        # look up the object identities array once, rather than once per ROI inside of the closure below
        id_data = obj_ids.data
        om = np.zeros(obj_ids.shape, obj_ids.dtype)

        def _label_object_class(roi):
            om[id_data == roi.name] = roi[class_name]

        classifications = self._df.classifications
        if filter_by is None:
            selected = classifications
        else:
            # keep only the ROIs for which every one of the specified classifications is truthy
            selected = classifications.loc[classifications[filter_by].all(axis=1), :]
        selected.apply(_label_object_class, axis=1)
        return InMemoryDataAccessor(om)

    def get_object_identities_overlay_map(
            self,
            white_channel,
            transparency: float = 0.5,
            mip: bool = False,
            rescale_clip: Union[float, None] = None,
    ) -> InMemoryDataAccessor:
        """
        Overlay a color-coded map of object identities on one channel of the raw stack
        :param white_channel: channel of the raw stack to show in gray
        :param transparency: the colored map is weighted by (1.0 - transparency) when added to the gray image
        :param mip: return the maximum intensity projection of the overlay if True
        :param rescale_clip: (optional) if specified, rescale the gray image with this clip value
        :return: accessor to an 8-bit RGB overlay
        """
        gray_acc = self.acc_raw.get_mono(channel=white_channel)
        if rescale_clip is not None:
            gray_acc = gray_acc.apply(lambda x: rescale(x, clip=rescale_clip))
        gray = gray_acc.to_8bit().data_yxz

        # one color per label, with black reserved for background (label 0)
        highest_label = self.df().index.max()
        colors = np.array([[0, 0, 0]] + glasbey.create_palette(highest_label, as_hex=False))
        colors_8bit = (255 * colors).round().astype('uint8')

        # palette lookup appends the color axis last (YXZC); swap it with Z to obtain YXCZ
        colored_yxzc = colors_8bit[self.acc_obj_ids.data_yxz]
        colored_yxcz = np.swapaxes(colored_yxzc, 2, 3)

        gray_rgb = np.stack((gray, gray, gray), axis=2)
        blended = gray_rgb + (1.0 - transparency) * colored_yxcz
        overlay = InMemoryDataAccessor(np.clip(blended, 0, 255).round().astype('uint8'))

        return overlay.get_mip() if mip else overlay

    def export_object_identities_overlay_map(self, where, white_channel, prefix='obmap_overlay', **kwargs) -> str:
        """
        Write the object identities overlay map to a TIF file
        :param where: directory in which to write the file
        :param white_channel: channel of the raw image to show in grayscale
        :param prefix: filename of the exported file, without extension
        :param kwargs: further options passed to get_object_identities_overlay_map()
        :return: name of the written file
        """
        overlay = self.get_object_identities_overlay_map(white_channel=white_channel, **kwargs)
        filename = prefix + '.tif'
        written = overlay.write(where / filename, composite=True)
        return written.name

    def get_serializable_dataframe(self) -> pd.DataFrame:
        """
        Return a copy of the RoiSet dataframe without the slice and binary mask columns
        :return: dataframe with the remaining columns
        """
        excluded_columns = [
            ('slices', 'expanded_slice'),
            ('slices', 'slice'),
            ('slices', 'relative_slice'),
            ('masks', 'binary_mask'),
        ]
        return self._df.drop(columns=excluded_columns)

    def export_dataframe(self, csv_path: Path) -> str:
        """
        Write the serializable part of the RoiSet dataframe to a CSV file, creating parent directories as needed
        :param csv_path: path of the CSV file to write
        :return: filename (without directory) of the written CSV file
        """
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        # move the index into a regular column, named on the second level of the column headers
        flat = self.get_serializable_dataframe().reset_index(col_level=1)
        flat.to_csv(csv_path, index=False)
        return csv_path.name

    def export_patch_masks(self, where: Path, prefix='mask', expanded=False, make_3d=True, mask_mip=False, **kwargs) -> pd.Series:
        """
        Export each patch mask to its own file.
        :param where: location in which to write patch mask files
        :param prefix: prefix of each patch mask's filename
        :param expanded: passed to get_patch_masks()
        :param make_3d: passed to get_patch_masks()
        :param mask_mip: passed to get_patch_masks()
        :param kwargs: accepted so that a full set of patch parameters can be passed in; not used here
        :return: pd.Series of patch mask filenames, indexed by ROI label; empty if there are no ROIs
        """
        patches_df = self.get_patch_masks(pad_to=None, expanded=expanded, make_3d=make_3d, mask_mip=mask_mip).copy()

        # get_patch_masks() returns a bare DataFrame without any columns when the RoiSet has no ROIs
        if patches_df.empty:
            return pd.Series(dtype=object)

        # write TIF if any bounding box spans more than one z-position, PNG otherwise
        bbox = patches_df.bounding_box
        if 'nz' in bbox.columns and (bbox['nz'] > 1).any():
            ext = 'tif'
        else:
            ext = 'png'

        def _export_patch_mask(roi):
            patch = InMemoryDataAccessor.from_mono(roi.masks.patch_mask)
            fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}'
            write_accessor_data_to_file(where / fname, patch)
            return fname

        return patches_df.apply(_export_patch_mask, axis=1)

    def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.Series:
        """
        Write one image file per ROI patch.
        :param where: directory that receives the patch files
        :param prefix: leading part of every patch filename
        :param kwargs: patch formatting options, forwarded to get_patches(); 'make_3d' and 'force_tif' also
            select TIF over PNG as the file format
        :return: pd.Series of patch filenames
        """
        wants_tif = kwargs.get('make_3d', False) or kwargs.get('force_tif')

        # copy the dataframe before generating patches, then attach the patches as an extra column
        with_patches = self._df.copy()
        with_patches['patches', 'patch'] = self.get_patches(**kwargs)

        def _write_one(roi):
            acc = InMemoryDataAccessor(roi.patches.patch)
            # PNG is only used for 2d patches of up to three channels
            ext = 'tif' if (wants_tif or acc.chroma > 3) else 'png'
            fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}'

            # 16-bit patches are converted to 8-bit before writing
            to_write = acc.to_8bit() if acc.dtype == 'uint16' else acc
            write_accessor_data_to_file(where / fname, to_write)
            return fname

        return with_patches.apply(_write_one, axis=1)

    def get_patch_masks(self, pad_to: int = None, expanded: bool = False, make_3d=True, mask_mip=False) -> pd.DataFrame:
        """
        Generate an 8-bit (0 or 255) patch mask for each ROI
        :param pad_to: (optional) if truthy, each mask is passed through pad() with this value
        :param expanded: if True, place the mask inside a patch the size of the expanded bounding box
        :param make_3d: for 3d RoiSets, return the mask as-is without selecting or projecting in z
        :param mask_mip: for 3d RoiSets where make_3d is False, return the maximum over the last axis
        :return: copy of the RoiSet dataframe with an added ('masks', 'patch_mask') column; an empty dataframe
            if there are no ROIs
        """
        if self.count == 0:
            return pd.DataFrame()

        def _mask_for_roi(roi):
            if expanded:
                ebb = roi.expanded_bounding_box
                canvas = np.zeros((ebb.ebb_h, ebb.ebb_w, 1, 1), dtype='uint8')
                canvas[roi.slices.relative_slice][:, :, 0, 0] = roi.masks.binary_mask * 255
            else:
                canvas = (roi.masks.binary_mask * 255).astype('uint8')

            if pad_to:
                canvas = pad(canvas, pad_to)

            if not self.is_3d:
                return np.expand_dims(canvas, 2)
            if make_3d:
                return canvas
            if mask_mip:
                # NOTE(review): this drops the last axis, unlike the other branches -- confirm callers expect that
                return np.max(canvas, axis=-1)

            # otherwise keep only the focus plane, indexed relative to the bounding box
            rel_zi = roi.bounding_box.zi - roi.bounding_box.z0
            return canvas[:, :, rel_zi: rel_zi + 1]

        with_masks = self._df.copy()
        with_masks['masks', 'patch_mask'] = with_masks.apply(_mask_for_roi, axis=1)
        return with_masks

    def get_patch_masks_acc(self, **kwargs) -> Union[PatchStack, None]:
        """
        Return all patch masks together as a single PatchStack
        :param kwargs: options passed to get_patch_masks()
        :return: PatchStack of patch masks, or None if there are no ROIs
        """
        if self.count == 0:
            return None
        patch_masks = self.get_patch_masks(**kwargs).masks.patch_mask
        # insert a new axis at position 2 of every mask before stacking
        return PatchStack([np.expand_dims(pm, 2) for pm in patch_masks])

    def get_patches(
            self,
            rescale_clip: float = 0.0,
            pad_to: int = None,
            make_3d: bool = False,
            focus_metric: str = None,
            rgb_overlay_channels: list = None,
            rgb_overlay_weights: Union[list, tuple] = (1.0, 1.0, 1.0),
            white_channel: int = None,
            expanded=False,
            update_focus_zi=False,
            **kwargs
    ) -> pd.Series:
        """
        Extract an image patch for each ROI, optionally focused in z, rescaled, and annotated
        :param rescale_clip: clip value passed to rescale(); rescaling is skipped if None
        :param pad_to: if truthy and expanded is True, each patch is passed through pad() with this value
        :param make_3d: if True, keep all z-positions of each patch; otherwise reduce to a single z-position
        :param focus_metric: key of focus_metrics() used to pick the z-position of 2d patches, separately for each
            channel; if None, 2d patches are taken from the middle of the patch's z-range
        :param rgb_overlay_channels: list of three raw channels (or None) to put in the R, G, and B channels
        :param rgb_overlay_weights: three weights applied to the respective overlay channels
        :param white_channel: raw channel to show in grayscale; on its own it gives single-channel patches
        :param expanded: if True, cut patches along the expanded instead of the tight bounding box
        :param update_focus_zi: if True, write the z-position of each patch back to ('bounding_box', 'zi')
        :param kwargs: further options read here: channels, draw_bounding_box, bounding_box_channel,
            bounding_box_linewidth, draw_mask, mask_channel, draw_contour, contour_channel
        :return: pd.Series of 4d patch arrays indexed like the RoiSet dataframe; an empty Series if there are
            no ROIs
        """
        if self.count == 0:
            return pd.Series(dtype=object)

        # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
        raw = self.acc_raw
        if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
            assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
            assert len(rgb_overlay_channels) == 3
            assert len(rgb_overlay_weights) == 3

            # compare against None so that channel 0 is a valid grayscale background
            if white_channel is not None:
                assert white_channel < raw.chroma
                mono = raw.get_mono(white_channel).data_yxz
                stack = np.stack([mono, mono, mono], axis=2)
            else:
                stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)

            for ii, ci in enumerate(rgb_overlay_channels):
                if ci is None:
                    continue
                assert isinstance(ci, int)
                assert ci < raw.chroma
                stack[:, :, ii, :] = safe_add(
                    stack[:, :, ii, :],  # either black or grayscale channel
                    rgb_overlay_weights[ii],
                    raw.get_mono(ci).data_yxz
                )
        else:
            if white_channel is not None:  # interpret as just a single channel
                assert white_channel < raw.chroma
                annotate_rgb = False
                for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
                    ca = kwargs.get(k)
                    if ca is None:
                        continue
                    assert (ca < raw.chroma)
                    if ca != white_channel:
                        annotate_rgb = True
                        break
                if annotate_rgb:  # make RGB patches anyway to include annotation color
                    mono = raw.get_mono(white_channel).data_yxz
                    stack = np.stack([mono, mono, mono], axis=2)
                else:  # make monochrome patches
                    stack = raw.get_mono(white_channel).data
            elif kwargs.get('channels'):
                stack = raw.get_channels(kwargs['channels']).data
            else:
                stack = raw.data

        def _absolute_zi(roi, rel_zi):  # convert a z-index within the patch to one within the full stack
            bbox = roi.bounding_box
            return (bbox.z0 + rel_zi) if hasattr(bbox, 'z0') else bbox.zi

        def _make_patch(roi):  # extract, focus, and annotate a patch
            if expanded:
                patch3d = stack[roi.slices.expanded_slice]
                subpatch = patch3d[roi.slices.relative_slice]
            else:
                patch3d = stack[roi.slices.slice]
                subpatch = patch3d

            ph, pw, pc, pz = patch3d.shape

            # make a 3d patch, focus stays where it is
            if make_3d:
                patch = patch3d.copy()
                # zi is already relative to the full stack, so it must not be offset by z0 again
                zi_focus = roi.bounding_box.zi

            # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
            elif focus_metric is not None:
                foc = focus_metrics()[focus_metric]

                patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)

                for ci in range(0, pc):
                    me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
                    zif = np.argmax(me)
                    patch[:, :, ci, 0] = patch3d[:, :, ci, zif]
                # the reported focus position is the one found for the last channel
                zi_focus = _absolute_zi(roi, zif)

            # make a 2d patch from middle of z-stack
            else:
                zif = floor(pz / 2)
                patch = patch3d[:, :, :, [zif]]
                zi_focus = _absolute_zi(roi, zif)

            assert len(patch.shape) == 4

            mask = np.zeros(patch3d.shape[0:2], dtype=bool)
            if expanded:
                mask[roi.slices.relative_slice[0:2]] = roi.masks.binary_mask
            else:
                mask = roi.masks.binary_mask

            if rescale_clip is not None:
                if rgb_overlay_channels:  # rescale all equally to preserve white balance
                    patch = rescale(patch, rescale_clip)
                else:
                    for ci in range(0, pc):  # rescale channels separately
                        patch[:, :, ci, :] = rescale(patch[:, :, ci, :], rescale_clip)

            if kwargs.get('draw_bounding_box') is True and expanded:
                bci = kwargs.get('bounding_box_channel', 0)
                assert bci < 3
                if bci > 0:
                    patch = make_rgb(patch)
                for zi in range(0, patch.shape[3]):
                    patch[:, :, bci, zi] = draw_box_on_patch(
                        patch[:, :, bci, zi],
                        (
                            (roi.relative_bounding_box.rel_x0, roi.relative_bounding_box.rel_y0),
                            (roi.relative_bounding_box.rel_x1, roi.relative_bounding_box.rel_y1)
                        ),
                        linewidth=kwargs.get('bounding_box_linewidth', 1)
                    )

            if kwargs.get('draw_mask'):
                mci = kwargs.get('mask_channel', 0)
                for zi in range(0, patch.shape[3]):
                    patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]

            if kwargs.get('draw_contour'):
                mci = kwargs.get('contour_channel', 0)
                # the mask is the same for every z-position, so its contours are found only once
                contours = get_safe_contours(mask)
                for zi in range(0, patch.shape[3]):
                    patch[:, :, mci, zi] = draw_contours_on_patch(
                        patch[:, :, mci, zi],
                        contours
                    )

            if pad_to and expanded:
                patch = pad(patch, pad_to)
            return {
                'patch': patch,
                'zif': zi_focus,
            }

        df_processed_patches = self._df.apply(_make_patch, axis=1, result_type='expand')
        if update_focus_zi:
            self._df.loc[:, ('bounding_box', 'zi')] = df_processed_patches['zif']
        return df_processed_patches['patch']

    @property
    def classification_columns(self) -> List[str]:
        """
        Names of the columns under 'classifications' in the RoiSet dataframe; an empty list if there are none
        """
        dfc = self._df.get('classifications')
        return [] if dfc is None else dfc.columns.to_list()

    def set_classification(self, classification_class: str, se: pd.Series):
        """
        Store an instance classification result as a column under 'classifications' in the RoiSet dataframe
        :param classification_class: name of the classification result, used as the column name
        :param se: series of class values, one per ROI
        """
        column_key = ('classifications', classification_class)
        self._df[column_key] = se

    def run_exports(self, where: Path, prefix, params: 'RoiSetExportParams') -> dict:
        """
        Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
        :param where: path of directory in which to write all export products
        :param prefix: prefix of the name of each product's file or subfolder
        :param params: RoiSetExportParams object describing which products to export and with which parameters
        :return: dict that maps the name of each export product to the path(s) of what was written, as strings;
            an empty dict if the RoiSet contains no ROIs
        """
        record = {}
        if not self.count:
            # hand back an empty record rather than None, so that callers can always treat the result as a dict
            return record

        for k, kp in params.dict().items():
            if kp is None:
                continue
            if k == 'patches':
                for pk, pp in kp.items():
                    product_name = f'patches_{pk}'
                    subdir = Path(product_name)
                    if params.write_patches_to_subdirectory:
                        subdir = subdir / prefix

                    if pp['is_patch_mask']:
                        se_paths = self.export_patch_masks(where / subdir, prefix=prefix, **pp).apply(lambda x: str(subdir / x))
                    else:
                        se_paths = self.export_patches(where / subdir, prefix=prefix, **pp).apply(lambda x: str(subdir / x))

                    # keep the path and a new unique ID of each exported patch in the RoiSet dataframe
                    df_patch_info = pd.DataFrame({
                        f'{product_name}_path': se_paths,
                        f'{product_name}_id': se_paths.apply(lambda _: uuid4()),
                    })
                    insert_level(df_patch_info, 'patches')
                    self._df = self._df.join(df_patch_info)
                    assert isinstance(self._df.columns, pd.MultiIndex)
                    record[product_name] = list(se_paths)
            if k == 'annotated_zstacks':
                record[k] = str(Path(k) / self.export_annotated_zstack(where / k, prefix=prefix, **kp))
            if k == 'object_classes':
                for n in self.classification_columns:
                    fp = where / k / n / (prefix + '.tif')
                    write_accessor_data_to_file(fp, self.get_object_class_map(n))
                    record[f'{k}_{n}'] = str(fp)
            if k == 'derived_channels':
                # NOTE(review): this runs for any value of the parameter other than None, including False -- confirm
                record[k] = []
                for di, dacc in enumerate(self.accs_derived):
                    fp = where / k / f'dc{di:01d}.tif'
                    fp.parent.mkdir(exist_ok=True, parents=True)
                    dacc.export_pyxcz(fp)
                    record[k].append(str(fp))
            if k == 'labels_overlay':
                fn = self.export_object_identities_overlay_map(where / k, prefix=prefix, **kp)
                record[k] = str(Path(k) / fn)

        # export dataframe and patch masks
        record = {
            **record,
            **self.serialize(
                where,
                prefix=prefix,
                write_patches_to_subdirectory=params.write_patches_to_subdirectory,
            ),
        }

        return record

    def get_export_product_accessors(self, params: RoiSetExportParams) -> dict:
        """
        Build the same kinds of products as run_exports(), but keep them in memory as accessors
        :param params: RoiSetExportParams object describing which products to make and with which parameters
        :return: ordered dict of product names mapped to accessors; empty if the RoiSet contains no ROIs
        """
        products = OrderedDict()
        if not self.count:
            return products

        for product, opts in params.dict().items():
            if opts is None:
                continue
            if product == 'patches':
                for patch_name, patch_opts in opts.items():
                    if patch_opts['is_patch_mask']:
                        # NOTE(review): 'patche_masks_' looks like a typo, kept because callers may rely on the key
                        products[f'patche_masks_{patch_name}'] = self.get_patch_masks_acc(**patch_opts)
                    else:
                        products[f'patches_{patch_name}'] = self.get_patches_acc(**patch_opts)
            elif product == 'annotated_zstacks':
                products[product] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **opts))
            elif product == 'object_classes':
                # one object class map per classification result
                for cname in self.classification_columns:
                    products[f'{product}_{cname}'] = self.get_object_class_map(cname)
            elif product == 'labels_overlay':
                products[product] = self.get_object_identities_overlay_map(**opts)

        return products

    def serialize(self, where: Path, prefix='roiset', allow_overwrite=True, write_patches_to_subdirectory=False) -> dict:
        """
        Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks
        :param where: path of directory in which to write files
        :param prefix: (optional) prefix
        :param allow_overwrite: freely overwrite CSV file of same name if True
        :param write_patches_to_subdirectory: if True, write all patches to a subdirectory with prefix as name
        :return: dict with the path of the CSV file under 'dataframe' and, if patch masks were written, a list of
            their paths under 'tight_patch_masks'; all paths are strings relative to where
        :raises SerializeRoiSetError: if the CSV file already exists and allow_overwrite is False
        """
        csv_path = where / 'dataframe' / (prefix + '.csv')

        # check this before anything is written, so that a refused overwrite leaves files and dataframe untouched
        if not allow_overwrite and csv_path.exists():
            raise SerializeRoiSetError(f'Cannot overwrite RoiSet file {csv_path}')

        record = {}
        if not self._df.masks.binary_mask.apply(lambda x: np.all(x)).all():  # binary masks aren't just all True
            subdir = Path('tight_patch_masks')
            if write_patches_to_subdirectory:
                subdir = subdir / prefix
            se_exp = self.export_patch_masks(
                where / subdir,
                prefix=prefix,
                pad_to=None,
                expanded=False
            )
            # record patch masks paths to dataframe, then save static columns to CSV
            se_pa = se_exp.apply(
                lambda x: str(subdir / x)
            ).rename('tight_patch_masks_path')
            self._df['patches', 'tight_patch_masks_path'] = se_pa
            record['tight_patch_masks'] = list(se_pa)

        csv_path.parent.mkdir(parents=True, exist_ok=True)
        self.export_dataframe(csv_path)

        record['dataframe'] = str(Path('dataframe') / csv_path.name)

        return record


    def get_polygons(self, poly_threshold=0, dilation_radius=1) -> pd.Series:
        """
        Fit polygons to all object boundaries in the RoiSet
        :param poly_threshold: threshold distance for polygon fit; a smaller number follows sharp features more closely
        :param dilation_radius: radius of binary dilation to apply before fitting polygon
        :return: Series of (variable x 2) np.ndarrays describing (x, y) polygon coordinates
        :raises PatchShapeError: if a binary mask is not two-dimensional, or if no outer contour can be found
        """

        pad_to = 1

        def _poly_from_mask(roi):
            mask = roi.masks.binary_mask
            if len(mask.shape) != 2:
                raise PatchShapeError('Patch mask needs to be two dimensions to fit a polygon')

            # label and fill holes
            labeled = label(mask)
            filled = [rp.image_filled for rp in regionprops(labeled)]
            assert (np.unique(labeled)[-1] == 1) and (len(filled) == 1), 'Cannot fit multiple polygons in a single patch mask'

            # pad after dilating, so that contours can close around objects that touch the edge of the patch
            closed = binary_dilation(filled[0], footprint=disk(dilation_radius))
            padded = np.pad(closed, pad_to) * 1.0
            all_contours = find_contours(padded)

            # the outer boundary is the first contour that encloses the points of all contours
            contour = next(
                (
                    outer for outer in all_contours
                    if all(points_in_poly(inner, outer).all() for inner in all_contours)
                ),
                None
            )
            if contour is None:
                raise PatchShapeError('Could not find an outer contour in patch mask')

            # contours are (row, column); swap to (x, y), undo the padding, then shift to image coordinates
            rel_polygon = approximate_polygon(contour[:, [1, 0]], poly_threshold) - [pad_to, pad_to]
            return rel_polygon + [roi.bounding_box.x0, roi.bounding_box.y0]

        return self._df.apply(_poly_from_mask, axis=1)


    @property
    def acc_obj_ids(self):
        """
        Object ID map built from the RoiSet dataframe and the shape of the raw image; it is rebuilt on every access
        """
        shape_dict = self.acc_raw.shape_dict
        return make_object_ids_from_df(self._df, shape_dict)


    @classmethod
    def deserialize(cls, acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self:
        """
        Rebuild an RoiSet from files written by serialize()
        :param acc_raw: accessor to the image that contains the ROIs
        :param where: directory that contains dataframe/<prefix>.csv and, optionally, a tight_patch_masks
            subdirectory
        :param prefix: filename of the CSV file without extension, and start of each patch mask filename
        :return: RoiSet object
        :raises DeserializeRoiSetError: if a patch mask cannot be read
        """
        df = read_roiset_df(where / 'dataframe' / (prefix + '.csv'))
        df.index.name = 'label'
        masks_dir = where / 'tight_patch_masks'
        is_3d = is_df_3d(df)

        if not masks_dir.exists():  # assume bounding boxes, exclusively 2d objects
            bbox = df.bounding_box
            yxhw = pd.DataFrame({
                'y': bbox['y0'],
                'x': bbox['x0'],
                'h': bbox['y1'] - bbox['y0'],
                'w': bbox['x1'] - bbox['x0'],
            })
            return cls.from_bounding_boxes(
                acc_raw,
                yxhw.to_dict(orient='records'),
                list(bbox['zi'])
            )

        # otherwise import segmentation masks; 3d masks are expected as TIF, 2d masks as PNG
        ext = 'tif' if is_3d else 'png'

        def _load_mask(roi):
            fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}'
            try:
                acc_mask = generate_file_accessor(masks_dir / fname)
                pixels = acc_mask.data_yxz if is_3d else acc_mask.data_yx
                return pixels / acc_mask.dtype_max
            except Exception as e:
                raise DeserializeRoiSetError(e)

        # rebuild the object ID map from the masks, then create the RoiSet from that map
        df['masks', 'binary_mask'] = df.apply(_load_mask, axis=1)
        id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
        return cls.from_object_ids(acc_raw, id_mask)


# Export parameters for RoiSetWithDerivedChannels: everything in RoiSetExportParams plus a derived_channels field.
class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams):
    # Presumably switches the export of derived channels on or off.
    # NOTE(review): run_exports() only skips parameters that are None, so False does not appear to prevent the
    # export -- confirm the intended behavior.
    derived_channels: bool = False


class RoiSetWithDerivedChannels(RoiSet):

    def __init__(self, *a, **k):
        # accessors returned by derived channel functions, in the order in which they were computed
        self.accs_derived = []
        super().__init__(*a, **k)

    def classify_by(
            self, name: str, channels: list[int],
            object_classification_model: InstanceMaskSegmentationModel,
            derived_channel_functions: list[callable] = None
    ):
        """
        Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
        specified inputs through an instance segmentation classifier.  Derive additional inputs for object
        classification by passing a raw input channel through one or more functions.

        :param name: name of column to insert
        :param channels: list of nc raw input channels to send to classifier
        :param object_classification_model: InstanceSegmentation model object
        :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
            that return a single-channel PatchStack accessor of the same shape
        :return: None
        """

        acc_in = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)
        if derived_channel_functions is not None:
            for fcn in derived_channel_functions:
                der = fcn(acc_in)  # returns patch stack
                self.accs_derived.append(der)

            # combine channels
            # NOTE(review): this appends every derived channel stored so far, including those from earlier calls
            # to classify_by() -- confirm that this is intended
            acc_app = acc_in
            for acc_der in self.accs_derived:
                acc_app = acc_app.append_channels(acc_der)

        else:
            acc_app = acc_in

        # do this on a patch basis, i.e. only one object per frame
        obmap_patches = object_classification_model.label_patch_stack(
            acc_app,
            self.get_patch_masks_acc(expanded=False, pad_to=None)
        )

        # NOTE(review): this adds an empty 'classify_by_<name>' column next to the one that set_classification()
        # adds below; kept as-is because it ends up in the exported dataframe
        self._df['classify_by_' + name] = pd.Series(dtype='Int16')

        se = pd.Series(dtype='Int64', index=self._df.index)

        for i, la in enumerate(self._df.index):
            # class of an ROI is the highest value found in the largest object of its labeled patch
            oc = np.unique(
                mask_largest_object(
                    obmap_patches.iat(i).data_yxz
                )
            )[-1]
            se[la] = oc
        self.set_classification(name, se)

    def run_exports(self, where: Path, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict:
        """
        Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
        :param where: path of directory in which to write all export products
        :param prefix: prefix of the name of each product's file or subfolder
        :param params: RoiSetWithDerivedChannelsExportParams object describing which products to export and with
            which parameters
        :return: dict that maps the name of each export product to the path(s) of what was written, as strings
        """
        # the parent returns None instead of a dict if there are no ROIs; fall back to an empty record so that
        # adding to it below cannot fail
        record = super().run_exports(where, prefix, params) or {}

        k = 'derived_channels'
        if k in params.dict().keys():
            record[k] = []
            for di, dacc in enumerate(self.accs_derived):
                fp = where / k / f'dc{di:01d}.tif'
                fp.parent.mkdir(exist_ok=True, parents=True)
                dacc.export_pyxcz(fp)
                record[k].append(str(fp))
        return record


class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
    def __init__(self, tr: float = 0.5):
        """
        Model that labels an object as class 1 if its mean intensity exceeds a threshold; all other pixels are
        left as zero
        :param tr: threshold as a fraction of the image's dtype_max, so that the model handles normalization to
            the full pixel intensity range
        """
        self.tr = tr
        self.loaded = self.load()
        super().__init__(info={'tr': tr})

    def load(self):
        # there is nothing to load for this model
        return True

    def infer(
            self,
            img: GenericImageDataAccessor,
            mask: GenericImageDataAccessor,
            allow_3d: bool = False,
            connect_3d: bool = True,
    ) -> GenericImageDataAccessor:
        """
        Label objects whose mean intensity exceeds the threshold as class 1
        :param img: single-channel image of intensities
        :param mask: mask of the objects to classify
        :param allow_3d: not used by this model
        :param connect_3d: not used by this model
        :return: accessor of the class map; a PatchStack if img is a PatchStack
        :raises ShapeMismatchError: if img has more or fewer than one channel
        """
        if img.chroma != 1:
            raise ShapeMismatchError(
                f'IntensityThresholdInstanceMaskSegmentationModel expects 1 channel but received {img.chroma}'
            )

        if isinstance(img, PatchStack):  # assume one object per patch
            df = img.get_object_df(mask)
            om = np.zeros(mask.shape, 'uint16')
            is_bright = df['intensity_mean'] > (self.tr * img.dtype_max)
            for la in df.loc[is_bright, 'label']:
                om[la] = (mask.iat(la).data > 0) * 1
            return PatchStack(om)

        labels = get_label_ids(mask)
        props = regionprops_table(
            labels.data_yxz,
            intensity_image=img.data_yxz,
            properties=('label', 'area', 'intensity_mean')
        )
        df = pd.DataFrame(props)

        om = np.zeros(labels.shape, labels.dtype)
        is_bright = df['intensity_mean'] > (self.tr * img.dtype_max)
        for la in df.loc[is_bright, 'label']:
            om[labels.data == la] = 1
        return InMemoryDataAccessor(om)

    def label_instance_class(
            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
    ) -> GenericImageDataAccessor:
        """
        Call the parent implementation with the same arguments, then return the result of infer()
        """
        super().label_instance_class(img, mask, **kwargs)
        return self.infer(img, mask)


class Error(Exception):
    """Base class of the exceptions defined in this module."""


class BoundingBoxError(Error):
    """Error related to bounding boxes; presumably raised for invalid ROI bounding boxes -- confirm at raise sites."""


class DataFrameQueryError(Error):
    """Error related to dataframe queries; presumably raised when a query fails -- confirm at raise sites."""


class DeserializeRoiSetError(Error):
    """Raised by RoiSet.deserialize() when a saved patch mask cannot be read."""


class SerializeRoiSetError(Error):
    """Raised by RoiSet.serialize() when the CSV file already exists and overwriting is not allowed."""


class NoDeprojectChannelSpecifiedError(Error):
    """Presumably raised when deprojection is requested without specifying a channel -- confirm at raise sites."""


class DerivedChannelError(Error):
    """Error related to derived channels; presumably raised when deriving one fails -- confirm at raise sites."""


class MissingSegmentationError(Error):
    """Presumably raised when an operation needs segmentation masks that are not available -- confirm at raise sites."""


class PatchShapeError(Error):
    """Raised when a patch or patch mask has an unsuitable shape, e.g. by RoiSet.get_polygons() for non-2d masks."""


class ShapeMismatchError(Error):
    """Raised when an image has an unexpected shape, e.g. a number of channels other than what a model expects."""


class MissingInstanceLabelsError(Error):
    """Presumably raised when instance labels are needed but not available -- confirm at raise sites."""