Skip to content
Snippets Groups Projects
roiset.py 39.2 KiB
Newer Older
from collections import OrderedDict
from math import floor
from pathlib import Path
from typing import Dict, List, Union
from uuid import uuid4

from typing_extensions import Self
import numpy as np
import pandas as pd
from pydantic import BaseModel, Field
from skimage import draw
from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops
from skimage.morphology import binary_dilation, disk

from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import InstanceMaskSegmentationModel
from model_server.base.process import get_safe_contours, pad, rescale, make_rgb, safe_add
from model_server.base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
from model_server.base.accessors import generate_file_accessor, PatchStack
from model_server.base.process import mask_largest_object
from model_server.rois.df import filter_df_overlap_seg, is_df_3d, insert_level, read_roiset_df, df_insert_slices
from model_server.rois.labels import get_label_ids, focus_metrics, make_df_from_object_ids, make_object_ids_from_df, \
    NoDeprojectChannelSpecifiedError


class PatchParams(BaseModel):
    # If True, run_exports / get_export_product_accessors produce binary patch masks instead of image patches
    is_patch_mask: bool = Field(default=False)
    # Channel selection: a single channel rendered as grayscale, or an explicit list of channels
    white_channel: Union[int, None] = Field(default=None)
    channels: Union[List[int], None] = Field(default=None)
    # Annotation toggles applied to each patch
    draw_bounding_box: bool = Field(default=False)
    draw_contour: bool = Field(default=False)
    draw_mask: bool = Field(default=False)
    rescale_clip: Union[float, None] = Field(default=None)
    # Key into focus_metrics(); used to choose the z-position of 2d patches per channel
    focus_metric: Union[str, None] = Field(default='max_sobel')
    # Up to three channels (or None) mapped onto R, G, B, with per-color weights
    rgb_overlay_channels: Union[List[Union[int, None]], None] = Field(default=None)
    rgb_overlay_weights: List[float] = Field(default=[1.0, 1.0, 1.0])
    pad_to: Union[int, None] = Field(default=256)
    # Write .tif files even when the patch could be saved as .png
    force_tif: bool = Field(default=False)
    update_focus_zi: bool = Field(
        default=False,
        description='If generating 2d patches with a 3d focus metric, update the RoiSet patch focus zi accordingly'
    )
    mask_mip: bool = Field(
        default=False,
        description='If generating 2d patch masks, use the MIP of the 3d patch mask instead of existing zi focus'
    )
class AnnotatedPatchParams(PatchParams):
    # Channel on which the bounding box is drawn, and the line width passed to draw_box_on_patch
    bounding_box_channel: int = Field(default=1)
    bounding_box_linewidth: int = Field(default=2)

class RoiSetLabelsOverlayParams(BaseModel):
    # Options for RoiSet.get_object_identities_overlay_map: palette-colored object labels blended onto one channel
    white_channel: int
    # Labels are added with weight (1 - transparency)
    transparency: float = Field(default=0.5)
    mip: bool = Field(default=False)
    rescale_clip: Union[float, None] = Field(default=None)

class AnnotatedZStackParams(BaseModel):
    # Keyword options forwarded to draw_boxes_on_3d_image when exporting annotated z-stacks
    draw_label: bool = Field(default=False)
    channel: Union[int, None] = Field(default=None)


class RoiFilterRange(BaseModel):
    # Lower and upper bounds of one ROI filter criterion; both are required
    min: float = Field(...)
    max: float = Field(...)


class RoiFilter(BaseModel):
    # Optional ranges on ROI geometry; each criterion may be left unset
    area: Union[RoiFilterRange, None] = Field(default=None)
    diag: Union[RoiFilterRange, None] = Field(default=None)
    min_hw: Union[RoiFilterRange, None] = Field(default=None)


class RoiSetMetaParams(BaseModel):
    # Optional ranges by which ROIs may be filtered
    filters: Union[RoiFilter, None] = None
    # Passed as expand_box_by to make_df_from_object_ids
    expand_box_by: List[int] = [0, 0]
    # Channel used to estimate each ROI's focus z-position when z-coordinates are not supplied; read as
    # params.deproject_channel by the RoiSet constructors
    deproject_channel: Union[int, None] = None
    deproject_intensity_threshold: float = 0.0


class RoiSetExportParams(BaseModel):
    # Selects which RoiSet export products to generate; get_export_product_accessors skips fields that are None.
    # Field order matters: products are generated in declaration order.
    patches: Union[Dict[str, PatchParams], None] = Field(default=None)  # dict key becomes part of the product name
    annotated_zstacks: Union[AnnotatedZStackParams, None] = Field(default=None)
    object_classes: bool = Field(default=False)
    labels_overlay: Union[RoiSetLabelsOverlayParams, None] = Field(default=None)
    write_patches_to_subdirectory: bool = Field(
        default=False,
        description='Write all patches to a subdirectory with prefix as name'
    )
            params: RoiSetMetaParams = RoiSetMetaParams(),
        """
        A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw.
        RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods.
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param df: dataframe containing at minimum bounding box and segmentation mask information
        :param params: optional arguments that influence the definition and representation of ROIs
        """
        df['patches', 'index'] = df.reset_index().index
            acc_raw: GenericImageDataAccessor,
            acc_obj_ids: GenericImageDataAccessor,
            params: RoiSetMetaParams = RoiSetMetaParams(),
    ):
        Create an RoiSet from an object identities map
        :param acc_raw: accessor to a generally multichannel z-stack
        :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity
            labels its membership in a connected object
        :param params: optional arguments that influence the definition and representation of ROIs
        :return: RoiSet object
        df = make_df_from_object_ids(
            acc_raw, acc_obj_ids,
            expand_box_by=params.expand_box_by,
            deproject_channel=params.deproject_channel,
            deproject_intensity_threshold=params.deproject_intensity_threshold,
    def from_bounding_boxes(
        acc_raw: GenericImageDataAccessor,
        bbox_yxhw: List[Dict],
        bbox_zi: Union[List[int], int] = None,
        params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create and RoiSet from bounding boxes
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param yxhw_list: list of bounding boxing coordinates [corner X, corner Y, height, width]
        :param params: optional arguments that influence the definition and representation of ROIs
        :return: RoiSet object
        """
        bbox_df = pd.DataFrame(bbox_yxhw)
        if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']:
            raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}')

        # deproject if zi is not specified
        if bbox_zi is None:
            dch = params.deproject_channel
            if dch is None or dch >= acc_raw.chroma or dch < 0:
                if acc_raw.chroma == 1:
                    dch = 0
                else:
                    raise NoDeprojectChannelSpecifiedError(
                        f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.'
                    )
            bbox_df['zi'] = acc_raw.get_mono(dch).get_focus_vector().argmax()
        else:
            bbox_df['zi'] = bbox_zi
        bbox_df['y0'] = bbox_df['y']
        bbox_df['x0'] = bbox_df['x']
        bbox_df['y1'] = bbox_df['y0'] + bbox_df['h']
        bbox_df['x1'] = bbox_df['x0'] + bbox_df['w']
        bbox_df.set_index('label', inplace=True)
        bbox_df = bbox_df.drop(['x', 'y', 'w', 'h'], axis=1)
        insert_level(bbox_df, 'bounding_box')
            return np.ones((int(r.h), int(r.w)), dtype=bool)
        df['masks', 'binary_mask'] = df.bounding_box.apply(
    def from_binary_mask(
            acc_raw: GenericImageDataAccessor,
            acc_seg: GenericImageDataAccessor,
            allow_3d=False,
            connect_3d=True,
            params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create a RoiSet from a binary segmentation mask (either 2D or 3D)
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param acc_seg: accessor of a binary segmentation mask (mono) of either two or three dimensions
        :param allow_3d: use a 3D map if True; use a 2D map of the mask's maximum intensity project if False
        :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
        :param params: optional arguments that influence the definition and representation of ROIs
        """
            acc_raw,
            get_label_ids(
                acc_seg,
                allow_3d=allow_3d,
                connect_3d=connect_3d
            ),
            params
        )
            acc_raw,
            polygons: List[np.ndarray],
            params: RoiSetMetaParams = RoiSetMetaParams()
    ):
        """
        Create a RoiSet where objects are defined from a list of polygon coordinates
        :param acc_raw: accessor to a generally a multichannel z-stack
        :param polygons: list of (variable x 2) np.ndarrays describing (x, y) polymer coordinates
        :param params: optional arguments that influence the definition and representation of ROIs
        """
        mask = np.zeros(acc_raw.get_mono(0, mip=True).shape, dtype=bool)
        for p in polygons:
            sl = draw.polygon(p[:, 1], p[:, 0])
            mask[sl] = True
            acc_raw,
            InMemoryDataAccessor(mask),
            allow_3d=False,
            connect_3d=False,
            params=params,
        )
    @property
    def info(self):
        """
        Summarize the RoiSet.
        :return: dict with the raw image's shape_dict, the ROI count, the classification column names, and the
            deep memory usage (bytes) of the RoiSet dataframe
        """
        return {
            'raw_shape_dict': self.acc_raw.shape_dict,
            'count': self.count,
            'classify_by': self.classification_columns,
            'df_memory_usage': int(self.df().memory_usage(deep=True).sum()),
        }
    def df(self) -> pd.DataFrame:
        """
        Return a deep copy of the RoiSet's dataframe, so that changes made by the caller do not affect the RoiSet
        """
        snapshot = self._df.copy(deep=True)
        return snapshot

    def list_bounding_boxes(self):
        """
        List each ROI's bounding box as a record.
        :return: list of dicts, one per ROI, containing the ROI's index value and its 'bounding_box' columns
        """
        # df is a method returning a copy; the original accessed self.df.bounding_box on the bound method itself
        return self.df().bounding_box.reset_index().to_dict(orient='records')
    def get_slices(self) -> pd.Series:
        """
        Return the per-ROI slices stored under the ('slices', 'slice') column of the RoiSet dataframe
        """
        table = self.df()
        column = ('slices', 'slice')
        return table[column]
    def add_df_col(self, name, se: pd.Series) -> None:
        """
        Insert (or overwrite) a column of the RoiSet's internal dataframe in place
        :param name: column key
        :param se: values of the column
        """
        frame = self._df
        frame[name] = se
    def get_patches_acc(self, channels: list = None, **kwargs) -> Union[PatchStack, None]:  # padded, un-annotated 2d patches
        if self.count == 0:
            return None
            return PatchStack(list(self.get_patches(white_channel=channels[0], **kwargs)))
            return PatchStack(list(self.get_patches(channels=channels, **kwargs)))
    def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> str:
        """
        Write the RoiSet's z-stack, annotated by draw_boxes_on_3d_image, to a TIF file
        :param where: directory (Path) in which to write the file
        :param prefix: filename stem; the file is named <prefix>.tif
        :param kwargs: options passed through to draw_boxes_on_3d_image
        :return: name of the written file, relative to where
        """
        annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs))
        fp = where / (prefix + '.tif')
        write_accessor_data_to_file(fp, annotated)
        # run_exports joins this return value onto a Path, so it must be the file name rather than None
        return fp.name
    def get_zmask(self, mask_type='boxes') -> np.ndarray:
        """
        Return a mask of same dimensionality as raw data

        :param kwargs: variable-length keyword arguments
            mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours'
        """

        assert mask_type in ('contours', 'boxes')
        zi_st = np.zeros(self.acc_raw.shape, dtype='bool')
        lamap = self.acc_obj_ids.data

        # make an object map where label is replaced by focus position in stack and background is -1
        lut = np.zeros(lamap.max() + 1) - 1

        if mask_type == 'contours':
            zi_map = (lut[lamap] + 1.0).astype('int')
            idxs = np.array(zi_map) - 1
            np.put_along_axis(
                zi_st,
                np.expand_dims(idxs, (2, 3)),
                1,
                axis=3
            )

            # change background level from to 0 in final frame
            zi_st[:, :, :, -1][lamap == 0] = 0

        elif mask_type == 'boxes':
            def _set_box(sl):
                zi_st[sl] = True
            self._df.slices.slice.apply(_set_box)
    def classify_by(
            self, name: str, channels: list[int],
            object_classification_model: InstanceMaskSegmentationModel,
    ):
        """
        Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
        specified inputs through an instance segmentation classifier.

        :param name: name of column to insert
        :param channels: list of nc raw input channels to send to classifier
        :param object_classification_model: InstanceSegmentation model object
        :return: None
        """
            self._df['classifications', name] = None
        input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
        # do this on a patch basis, i.e. only one object per frame
        obmap_patches = object_classification_model.label_patch_stack(
            self.get_patch_masks_acc(expanded=False, pad_to=None)
        se = pd.Series(dtype='Int64', index=self._df.index)
        for i, la in enumerate(self._df.index):
            oc = np.unique(
                mask_largest_object(
                    obmap_patches.iat(i).data_yxz
        self.set_classification(name, se)
    def get_instance_classification(self, roiset_from: Self, iou_min: float = 0.5) -> pd.DataFrame:
        """
        Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity
        :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet
        :param iou_min: threshold IOU below which a label is not transferred
        :return DataFrame of source RoiSet, including overlaps with this RoiSet and IOU metric
        """
        if self.acc_raw.shape != roiset_from.acc_raw.shape:
            raise ShapeMismatchError(
                f'Expecting two RoiSets of same shape: {self.acc_raw.shape} != {roiset_from.acc_raw.shape}')

        columns = roiset_from.classification_columns

        if len(columns) == 0:
            raise MissingInstanceLabelsError('Expecting at least on instance classification channel but none found')

        df_overlaps = filter_df_overlap_seg(
        df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min
            roiset_from.df().classifications[columns],
            df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']],
            left_index=True,
            right_index=True,
            how='inner',
        ).set_index('overlaps_with')
        for col in columns:
            self.set_classification(col, df_merge[col])

    def get_object_class_map(self, class_name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor:
        """
        For a given classification result, return a map where object IDs are replaced by each object's class
        :param name: name of the classification result, same as passed to RoiSet.classify_by()
        :param filter_by: only include ROIs if the intersection of all specified classifications is True
        assert class_name in self._df.classifications.columns
        obj_ids = self.acc_obj_ids
        om = np.zeros(obj_ids.shape, obj_ids.dtype)

        def _label_object_class(roi):
            om[self.acc_obj_ids.data == roi.name] = roi[class_name]
            self._df.classifications.apply(_label_object_class, axis=1)
            pd_fil = self._df.classifications[filter_by]
            self._df.classifications.loc[pd_fil.all(axis=1), :].apply(_label_object_class, axis=1)
    def get_object_identities_overlay_map(
            self,
            white_channel,
            transparency: float = 0.5,
            mip: bool = False,
        mono = self.acc_raw.get_mono(channel=white_channel)
        if rescale_clip is not None:
            mono = mono.apply(lambda x: rescale(x, clip=rescale_clip))
        mono = mono.to_8bit().data_yxz
        max_label = self.df().index.max()
        palette = np.array([[0, 0, 0]] + glasbey.create_palette(max_label, as_hex=False))
        rgb_8bit_palette = (255 * palette).round().astype('uint8')
        id_map_yxzc = rgb_8bit_palette[self.acc_obj_ids.data_yxz]
        id_map_yxcz = np.moveaxis(
            id_map_yxzc,
            [0, 1, 3, 2],
            [0, 1, 2, 3]
        )
        combined = np.stack([mono, mono, mono], axis=2) + (1.0 - transparency) * id_map_yxcz
        combined_8bit = np.clip(combined, 0, 255).round().astype('uint8')

        acc_out = InMemoryDataAccessor(combined_8bit)
        if mip:
            return acc_out.get_mip()
        else:
            return acc_out
    def export_object_identities_overlay_map(self, where, white_channel, prefix='obmap_overlay', **kwargs) -> str:
        """
        Write the object-identities overlay map (see get_object_identities_overlay_map) as a composite TIF
        :param where: directory (Path) in which to write the file
        :param white_channel: channel used as the grayscale background of the overlay
        :param prefix: filename stem; the file is named <prefix>.tif
        :param kwargs: further options passed to get_object_identities_overlay_map
        :return: name of the written file
        """
        overlay = self.get_object_identities_overlay_map(white_channel=white_channel, **kwargs)
        target = where / (prefix + '.tif')
        written = overlay.write(target, composite=True)
        return written.name

    def get_serializable_dataframe(self) -> pd.DataFrame:
            ('slices', 'expanded_slice'),
            ('slices', 'slice'),
            ('slices', 'relative_slice'),
            ('masks', 'binary_mask')
    def export_dataframe(self, csv_path: Path) -> str:
        """
        Write the serializable part of the RoiSet dataframe to a CSV file, creating parent directories as needed
        :param csv_path: path of the CSV file to write
        :return: name of the written CSV file (previously None despite the str annotation)
        """
        csv_path.parent.mkdir(parents=True, exist_ok=True)
        # move the index into a regular column; col_level=1 puts its name on the second column level
        self.get_serializable_dataframe().reset_index(
            col_level=1,
        ).to_csv(
            csv_path,
            index=False,
        )
        return csv_path.name
    def export_patch_masks(self, where: Path, prefix='mask', expanded=False, make_3d=True, mask_mip=False, **kwargs) -> pd.DataFrame:
        patches_df = self.get_patch_masks(pad_to=None, expanded=expanded, make_3d=make_3d, mask_mip=mask_mip).copy()
        if 'nz' in patches_df.bounding_box.columns and any(patches_df.bounding_box['nz'] > 1):
            patch = InMemoryDataAccessor.from_mono(roi.masks.patch_mask)
            fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}'
            write_accessor_data_to_file(where / fname, patch)
        return patches_df.apply(_export_patch_mask, axis=1)
    def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.Series:
        """
        Export each patch to its own file.
        :param where: location in which to write patch files
        :param prefix: prefix of each patch's filename
        :param kwargs: patch formatting options
        :return: pd.Series of patch paths
        """
        make_3d = kwargs.get('make_3d', False)
        patches_df = self._df.copy()
        patches_df['patches', 'patch'] = self.get_patches(**kwargs)
            patch = InMemoryDataAccessor(roi.patches.patch)
            ext = 'tif' if make_3d or patch.chroma > 3 or kwargs.get('force_tif') else 'png'
            fname = f'{prefix}-la{roi.name:04d}-zi{roi.bounding_box.zi:04d}.{ext}'
                resampled = patch.to_8bit()
                write_accessor_data_to_file(where / fname, resampled)
                write_accessor_data_to_file(where / fname, patch)
        return patches_df.apply(_export_patch, axis=1)
    def get_patch_masks(self, pad_to: int = None, expanded: bool = False, make_3d=True, mask_mip=False) -> pd.DataFrame:
                patch = np.zeros((roi.expanded_bounding_box.ebb_h, roi.expanded_bounding_box.ebb_w, 1, 1), dtype='uint8')
                patch[roi.slices.relative_slice][:, :, 0, 0] = roi.masks.binary_mask * 255
                patch = (roi.masks.binary_mask * 255).astype('uint8')
            if pad_to:
                patch = pad(patch, pad_to)
            if self.is_3d and make_3d:
            elif self.is_3d and mask_mip:
                return np.max(patch, axis=-1)
                rzi = roi.bounding_box.zi - roi.bounding_box.z0
                return patch[:, :, rzi: rzi + 1]
                return np.expand_dims(patch, 2)
        dfe['masks', 'patch_mask'] = dfe.apply(_make_patch_mask, axis=1)
    def get_patch_masks_acc(self, **kwargs) -> Union[PatchStack, None]:
        """
        Return ROI patch masks as a PatchStack, or None when the RoiSet holds no ROIs
        :param kwargs: options passed through to get_patch_masks
        """
        if self.count == 0:
            return None
        masks = self.get_patch_masks(**kwargs).masks.patch_mask
        # insert a singleton axis at position 2 of every mask before stacking
        return PatchStack([np.expand_dims(m, 2) for m in masks])
    def get_patch_obmap_acc(self, **kwargs) -> Union[PatchStack, None]:
        """
        Return a PatchStack of object-identity patches: each patch mask's foreground is set to its ROI's label
        :param kwargs: options passed through to get_patch_masks_acc
        :return: PatchStack of label patches, or None when the RoiSet holds no ROIs (matching get_patch_masks_acc)
        """
        acc_masks = self.get_patch_masks_acc(**kwargs)
        if acc_masks is None:
            return None
        # NOTE(review): pairs the i-th mask with the i-th *sorted* label, which assumes patch masks come out in
        # ascending label order; confirm against get_patch_masks
        labels = self.df().index.sort_values().to_list()
        return PatchStack([(acc_masks.iat(i).data > 0) * la for i, la in enumerate(labels)])

    def get_patches(
            self,
            rescale_clip: float = 0.0,
            make_3d: bool = False,
            focus_metric: str = None,
            rgb_overlay_channels: list = None,
            rgb_overlay_weights: list = [1.0, 1.0, 1.0],
            white_channel: int = None,
        # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
        raw = self.acc_raw
        if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
            assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
            assert len(rgb_overlay_channels) == 3
            assert len(rgb_overlay_weights) == 3

            if white_channel:
                assert white_channel < raw.chroma
                mono = raw.get_mono(white_channel).data_yxz
                stack = np.stack([mono, mono, mono], axis=2)
            else:
                stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)

            for ii, ci in enumerate(rgb_overlay_channels):
                if ci is None:
                    continue
                assert isinstance(ci, int)
                assert ci < raw.chroma
                stack[:, :, ii, :] = safe_add(
                    stack[:, :, ii, :],  # either black or grayscale channel
                    rgb_overlay_weights[ii],
            if white_channel is not None:  # interpret as just a single channel
                assert white_channel < raw.chroma
                annotate_rgb = False
                for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
                    ca = kwargs.get(k)
                    if ca is None:
                        continue
                    assert (ca < raw.chroma)
                    if ca != white_channel:
                        annotate_rgb = True
                        break
                if annotate_rgb:  # make RGB patches anyway to include annotation color
                    mono = raw.get_mono(white_channel).data_yxz
                    stack = np.stack([mono, mono, mono], axis=2)
                else:  # make monochrome patches
                    stack = raw.get_mono(white_channel).data
            elif kwargs.get('channels'):
                stack = raw.get_channels(kwargs['channels']).data
        def _make_patch(roi):  # extract, focus, and annotate a patch
                patch3d = stack[roi.slices.expanded_slice]
                subpatch = patch3d[roi.slices.relative_slice]
            # make a 3d patch, focus stays where it is

            # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
            elif focus_metric is not None:
                foc = focus_metrics()[focus_metric]

                patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)

                for ci in range(0, pc):
                    me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
                    zif = np.argmax(me)
                    patch[:, :, ci, 0] = patch3d[:, :, ci, zif]

            # make a 2d patch from middle of z-stack
            else:
                zif = floor(pz / 2)
                patch = patch3d[:, :, :, [zif]]
            mask = np.zeros(patch3d.shape[0:2], dtype=bool)
            if expanded:
                mask[roi.slices.relative_slice[0:2]] = roi.masks.binary_mask
                if rgb_overlay_channels:  # rescale all equally to preserve white balance
                    patch = rescale(patch, rescale_clip)
                else:
                    for ci in range(0, pc):  # rescale channels separately
                        patch[:, :, ci, :] = rescale(patch[:, :, ci, :], rescale_clip)
            if kwargs.get('draw_bounding_box') is True and expanded:
                bci = kwargs.get('bounding_box_channel', 0)
                assert bci < 3
                if bci > 0:
                    patch = make_rgb(patch)
                for zi in range(0, patch.shape[3]):
                    patch[:, :, bci, zi] = draw_box_on_patch(
                        patch[:, :, bci, zi],
                        (
                            (roi.relative_bounding_box.rel_x0, roi.relative_bounding_box.rel_y0),
                            (roi.relative_bounding_box.rel_x1, roi.relative_bounding_box.rel_y1)
                        ),
                        linewidth=kwargs.get('bounding_box_linewidth', 1)
                    )

            if kwargs.get('draw_mask'):
                mci = kwargs.get('mask_channel', 0)
                for zi in range(0, patch.shape[3]):
                    patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]

            if kwargs.get('draw_contour'):
                mci = kwargs.get('contour_channel', 0)

                for zi in range(0, patch.shape[3]):
                    patch[:, :, mci, zi] = draw_contours_on_patch(
                        patch[:, :, mci, zi],
            if pad_to and expanded:
                'zif': (roi.bounding_box.z0 + zif) if hasattr(roi.bounding_box, 'z0') else roi.bounding_box.zi,
        df_processed_patches = self._df.apply(lambda r: _make_patch(r), axis=1, result_type='expand')
        if update_focus_zi:
            self._df.loc[:, ('bounding_box', 'zi')] = df_processed_patches['zif']
        return df_processed_patches['patch']
    @property
    def classification_columns(self) -> List[str]:
        """
        Names of the columns in the 'classifications' group of the RoiSet dataframe, i.e. the instance
        classification results stored so far; empty list if there are none.
        Accessed as an attribute by callers (e.g. len(roiset.classification_columns)), hence a property.
        """
        if (dfc := self._df.get('classifications')) is None:
            return []
        return dfc.columns.to_list()
    def set_classification(self, classification_class: str, se: pd.Series):
        """
        Store an instance classification result in the RoiSet dataframe, under the 'classifications' column group
        :param classification_class: name of classification result
        :param se: series containing class information; pandas aligns it to the dataframe's index on assignment
        """
        column_key = ('classifications', classification_class)
        self._df[column_key] = se
    def run_exports(self, where: Path, prefix, params: RoiSetExportParams) -> dict:
        """
        Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
        :param where: path of directory in which to write all export products
        :param prefix: prefix of the name of each product's file or subfolder
        :param params: RoiSetExportParams object describing which products to export and with which parameters
        :return: nested dict of Path objects describing the location of export products
            if k == 'patches':
                for pk, pp in kp.items():
                    product_name = f'patches_{pk}'
                    subdir = Path(product_name)
                    if params.write_patches_to_subdirectory:
                        subdir = subdir / prefix

                    if pp['is_patch_mask']:
                        se_paths = self.export_patch_masks(where / subdir, prefix=prefix, **pp).apply(lambda x: str(subdir / x))
                    else:
                        se_paths = self.export_patches(where / subdir, prefix=prefix, **pp).apply(lambda x: str(subdir / x))

                    df_patch_info = pd.DataFrame({
                        f'{product_name}_path': se_paths,
                        f'{product_name}_id': se_paths.apply(lambda _: uuid4()),
                    })
                    insert_level(df_patch_info, 'patches')
                    assert isinstance(self._df.columns, pd.MultiIndex)
                record[k] = str(Path(k) / self.export_annotated_zstack(where / k, prefix=prefix, **kp))
            if k == 'object_classes':
                    write_accessor_data_to_file(fp, self.get_object_class_map(n))
                    record[f'{k}_{n}'] = str(fp)
            if k == 'derived_channels':
                record[k] = []
                for di, dacc in enumerate(self.accs_derived):
                    fp.parent.mkdir(exist_ok=True, parents=True)
                    dacc.export_pyxcz(fp)
                    record[k].append(str(fp))
            if k == 'labels_overlay':
                fn = self.export_object_identities_overlay_map(where / k, prefix=prefix, **kp)
                record[k] = str(Path(k) / fn)
        record = {
            **record,
            **self.serialize(
                where,
                prefix=prefix,
                write_patches_to_subdirectory=params.write_patches_to_subdirectory,
            ),
        }
    def get_export_product_accessors(self, params: RoiSetExportParams) -> dict:
        """
        Return various representations of ROIs, e.g. patches, annotated stacks, and object maps, as accessors
        :param params: RoiSetExportParams object describing which products to export and with which parameters
        :return: ordered dict of accessors containing the specified products; empty if the RoiSet has no ROIs
        """
        interm = OrderedDict()
        if not self.count:
            return interm

        # get_patch_masks() accepts only these keywords; other patch params would raise TypeError
        mask_keys = ('pad_to', 'expanded', 'make_3d', 'mask_mip')

        for k, kp in params.dict().items():
            if kp is None:
                continue
            if k == 'patches':
                for pk, pp in kp.items():
                    if pp['is_patch_mask']:
                        mask_kwargs = {a: pp[a] for a in mask_keys if a in pp}
                        # NOTE(review): 'patche_masks_' looks like a typo but is kept since callers may rely on it
                        interm[f'patche_masks_{pk}'] = self.get_patch_masks_acc(**mask_kwargs)
                    else:
                        interm[f'patches_{pk}'] = self.get_patches_acc(**pp)
            if k == 'annotated_zstacks':
                interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp))
            if k == 'object_classes':
                # NOTE(review): object_classes is a bool, so this branch also runs when it is False; confirm intent
                for n in self.classification_columns:
                    interm[f'{k}_{n}'] = self.get_object_class_map(n)
            if k == 'labels_overlay':
                interm[k] = self.get_object_identities_overlay_map(**kp)
        return interm
    def serialize(self, where: Path, prefix='roiset', allow_overwrite=True, write_patches_to_subdirectory=False) -> dict:
        """
        Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks
        :param where: path of directory in which to write files
        :param prefix: (optional) prefix
        :param allow_overwrite: freely overwrite CSV file of same name if True
        :param write_patches_to_subdirectory: if True, write all patches to a subdirectory with prefix as name
        :return: nested dict of Path objects describing the locations of export products
        """
        record = {}
        if not self._df.masks.binary_mask.apply(lambda x: np.all(x)).all():  # binary masks aren't just all True
            subdir = Path('tight_patch_masks')
            if write_patches_to_subdirectory:
                subdir = subdir / prefix
            se_exp = self.export_patch_masks(
                prefix=prefix,
                pad_to=None,
                expanded=False
            )
            # record patch masks paths to dataframe, then save static columns to CSV
            se_pa = se_exp.apply(
            self._df['patches', 'tight_patch_masks_path'] = se_exp.apply(lambda x: str(subdir / x))

        csv_path = where / 'dataframe' / (prefix + '.csv')
        if not allow_overwrite and csv_path.exists():
            raise SerializeRoiSetError(f'Cannot overwrite RoiSet file {csv_path.__str__()}')
        csv_path.parent.mkdir(parents=True, exist_ok=True)

        record['dataframe'] = str(Path('dataframe') / csv_path.name)
    def get_polygons(self, poly_threshold=0, dilation_radius=1) -> pd.DataFrame:
        """
        Fit polygons to all object boundaries in the RoiSet
        :param poly_threshold: threshold distance for polygon fit; a smaller number follows sharp features more closely
        :param dilation_radius: radius of binary dilation to apply before fitting polygon
        :return: Series, indexed like the RoiSet, of (variable x 2) np.ndarrays describing (x, y) polygon coordinates
        :raises PatchShapeError: if an ROI's binary mask is not 2D, or if no contour contains all others
        """

        pad_to = 1  # pad by one pixel so objects touching the mask edge still yield closed contours

        def _poly_from_mask(roi):
            mask = roi.masks.binary_mask
            if len(mask.shape) != 2:
                raise PatchShapeError(f'Patch mask needs to be two dimensions to fit a polygon')

            # label and fill holes; exactly one object is expected per mask
            labeled = label(mask)
            regions = regionprops(labeled)
            filled = [rp.image_filled for rp in regions]
            assert (np.unique(labeled)[-1] == 1) and (len(filled) == 1), 'Cannot fit multiple polygons in a single patch mask'

            # image_filled is cropped to the region's own bounding box; keep its origin within the mask
            row0, col0 = regions[0].bbox[0:2]

            closed = binary_dilation(filled[0], footprint=disk(dilation_radius))
            padded = np.pad(closed, pad_to) * 1.0
            all_contours = find_contours(padded)

            # keep the first contour whose polygon contains the points of every contour, i.e. the outer boundary
            for candidate in all_contours:
                if all(points_in_poly(other, candidate).all() for other in all_contours):
                    contour = candidate
                    break
            else:
                raise PatchShapeError(f'Could not find an outer contour for ROI {roi.name}')

            # (row, col) -> (x, y); undo padding; shift from region-box to mask coordinates
            rel_polygon = approximate_polygon(contour[:, [1, 0]], poly_threshold) - [pad_to, pad_to] + [col0, row0]
            return rel_polygon + [roi.bounding_box.x0, roi.bounding_box.y0]

        return self._df.apply(_poly_from_mask, axis=1)

        return make_object_ids_from_df(self._df, self.acc_raw.shape_dict)
Christopher Randolph Rhodes's avatar
Christopher Randolph Rhodes committed
    def extract_features(
            self,
            extractor: callable,
            **kwargs,
    ):
        """
        Compute per-ROI features with an extractor and join them onto the RoiSet dataframe under a 'features' group
        :param extractor: function that takes an RoiSet object and returns a DataFrame of features
        :param kwargs: variable-length keyword arguments that are passed to feature extractor
        """
        df_features = extractor(self, **kwargs)
        # presumably adds 'features' as the top column level in place — confirm in rois.df.insert_level
        insert_level(df_features, 'features')
        self._df = self._df.join(df_features)


    def get_features(self) -> pd.DataFrame:
        """
        Return the 'features' column group of the RoiSet dataframe, or None if no features have been joined yet
        """
        table = self.df()
        features = table.get('features')
        return features
    @classmethod
    def deserialize(cls, acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self:
        """
        Create an RoiSet object from saved files and an image accessor
        :param acc_raw: accessor to image that contains ROIs
        :param where: path to directory containing RoiSet serialization files, namely dataframe.csv and a subdirectory
            named tight_patch_masks
        :param prefix: starting prefix of patch mask filenames
        df = read_roiset_df(where / 'dataframe' / (prefix + '.csv'))
        is_3d = is_df_3d(df)
        ext = 'tif' if is_3d else 'png'

        if pa_masks.exists():  # import segmentation masks
            def _read_binary_mask(r):
                fname = f'{prefix}-la{r.name:04d}-zi{r.bounding_box.zi:04d}.{ext}'
                try:
                    ma_acc = generate_file_accessor(pa_masks / fname)
                        mask_data = ma_acc.data_yxz / ma_acc.dtype_max
                    else:
                        mask_data = ma_acc.data_yx / ma_acc.dtype_max
            df['masks', 'binary_mask'] = df.apply(_read_binary_mask, axis=1)
            id_mask = make_object_ids_from_df(df, acc_raw.shape_dict)
        else:  # assume bounding boxes, exclusively 2d objects
            df['bounding_box', 'y'] = df.bounding_box['y0']
            df['bounding_box', 'x'] = df.bounding_box['x0']
            df['bounding_box', 'h'] = df.bounding_box['y1'] - df.bounding_box['y0']
            df['bounding_box', 'w'] = df.bounding_box['x1'] - df.bounding_box['x0']
                df.bounding_box[['y', 'x', 'h', 'w']].to_dict(orient='records'),
                list(df.bounding_box['zi'])
class BoundingBoxError(Error):
    """Raised when bounding box coordinates are not given as exactly Y, X, H, and W."""


class DerivedChannelError(Error):
    pass


class PatchShapeError(Error):
    """Raised when a patch mask does not have the dimensionality an operation requires."""


class ShapeMismatchError(Error):
    """Raised when two RoiSets that must share an image shape do not."""


class MissingInstanceLabelsError(Error):
    """Raised when a source RoiSet has no instance classification columns to transfer."""


class SerializeRoiSetError(Error):
    """Raised when serializing a RoiSet would overwrite an existing file and overwriting is not allowed."""