From 1af591200248727678b2507b364892b809428b24 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 7 Feb 2024 15:51:35 +0100
Subject: [PATCH] Rearranged methods, condensed into roiset package

---
 model_server/base/process.py                  |  15 +-
 model_server/extensions/chaeo/accessors.py    |  32 +++
 model_server/extensions/chaeo/products.py     | 247 -----------------
 .../extensions/chaeo/{zmask.py => roiset.py}  | 251 ++++++++++++++++--
 .../extensions/chaeo/tests/test_zstack.py     |  42 ++-
 model_server/extensions/chaeo/workflows.py    |   4 +-
 6 files changed, 293 insertions(+), 298 deletions(-)
 rename model_server/extensions/chaeo/{zmask.py => roiset.py} (53%)

diff --git a/model_server/base/process.py b/model_server/base/process.py
index 48128223..3ee227d7 100644
--- a/model_server/base/process.py
+++ b/model_server/base/process.py
@@ -71,4 +71,17 @@ def rescale(nda, clip=0.0):
     clip_pct = (100.0 * clip, 100.0 * (1.0 - clip))
     cmin, cmax = np.percentile(nda, clip_pct)
     rescaled = rescale_intensity(nda, in_range=(cmin, cmax))
-    return rescaled
\ No newline at end of file
+    return rescaled
+
+
+def make_rgb(nda):
+    """
+    Convert a YXCZ stack array to RGB, and error if more than three channels
+    :param nda: np.ndarray (YXCZ dimensions)
+    :return: np.ndarray of 3-channel stack
+    """
+    h, w, c, nz = nda.shape
+    assert c <= 3
+    outdata = np.zeros((h, w, 3, nz), dtype=nda.dtype)
+    outdata[:, :, 0:c, :] = nda[:, :, :, :]
+    return outdata
diff --git a/model_server/extensions/chaeo/accessors.py b/model_server/extensions/chaeo/accessors.py
index bc65c6f3..1f1278a3 100644
--- a/model_server/extensions/chaeo/accessors.py
+++ b/model_server/extensions/chaeo/accessors.py
@@ -1,7 +1,10 @@
 from pathlib import Path
 
 import numpy as np
+from skimage.io import imsave
+from tifffile import imwrite
 
+from base.process import make_rgb
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor
 
 class MonoPatchStack(InMemoryDataAccessor):
@@ -127,6 +130,33 @@ class Multichannel3dPatchStack(InMemoryDataAccessor):
     def shape_dict(self):
         return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
 
+
+def write_patch_to_file(where, fname, yxcz):
+    ext = fname.split('.')[-1].upper()
+    where.mkdir(parents=True, exist_ok=True)
+
+    if ext == 'PNG':
+        assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}'
+        assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
+        assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs'
+        if yxcz.shape[2] == 1:
+            outdata = yxcz[:, :, 0, 0]
+        elif yxcz.shape[2] == 2: # add a blank blue channel
+            outdata = make_rgb(yxcz)
+        else: # preserve RGB order
+            outdata = yxcz[:, :, :, 0]
+        imsave(where / fname, outdata, check_contrast=False)
+        return True
+
+    elif ext in ['TIF', 'TIFF']:
+        zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3])
+        imwrite(where / fname, zcyx, imagej=True)
+        return True
+
+    else:
+        raise Exception(f'Unsupported file extension: {ext}')
+
+
 class Error(Exception):
     pass
 
@@ -135,3 +165,5 @@ class InvalidDataForPatchStackError(Error):
 
 class FileNotFoundError(Error):
     pass
+
+
diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index 484756fa..b28b04f6 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -1,250 +1,3 @@
-from math import floor, sqrt
-from pathlib import Path
 
-import numpy as np
-import pandas as pd
-from scipy.stats import moment
-from skimage.filters import sobel
-from skimage.io import imsave
-from skimage.measure import find_contours, shannon_entropy
-from tifffile import imwrite
 
-from model_server.extensions.chaeo.accessors import MonoPatchStack
-from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch
-from model_server.base.accessors import InMemoryDataAccessor
-from model_server.base.process import pad, rescale, resample_to_8bit
 
-def _make_rgb(zs):
-    h, w, c, nz = zs.shape
-    assert c <= 3
-    outdata = np.zeros((h, w, 3, nz), dtype=zs.dtype)
-    outdata[:, :, 0:c, :] = zs[:, :, :, :]
-    return outdata
-
-def _focus_metrics():
-    return {
-        'max_intensity': lambda x: np.max(x),
-        'stdev': lambda x: np.std(x),
-        'max_sobel': lambda x: np.max(sobel(x)),
-        'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)),
-        'entropy': lambda x: shannon_entropy(x),
-        'moment': lambda x: moment(x.flatten(), moment=2),
-    }
-
-def _safe_add(a, g, b):
-    assert a.dtype == b.dtype
-    assert a.shape == b.shape
-    assert g >= 0.0
-
-    return np.clip(
-        a.astype('uint32') + g * b.astype('uint32'),
-        0,
-        np.iinfo(a.dtype).max
-    ).astype(a.dtype)
-
-def write_patch_to_file(where, fname, yxcz):
-    ext = fname.split('.')[-1].upper()
-    where.mkdir(parents=True, exist_ok=True)
-
-    if ext == 'PNG':
-        assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}'
-        assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
-        assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs'
-        if yxcz.shape[2] == 1:
-            outdata = yxcz[:, :, 0, 0]
-        elif yxcz.shape[2] == 2: # add a blank blue channel
-            outdata = _make_rgb(yxcz)
-        else: # preserve RGB order
-            outdata = yxcz[:, :, :, 0]
-        imsave(where / fname, outdata, check_contrast=False)
-        return True
-
-    elif ext in ['TIF', 'TIFF']:
-        zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3])
-        imwrite(where / fname, zcyx, imagej=True)
-        return True
-
-    else:
-        raise Exception(f'Unsupported file extension: {ext}')
-
-
-
-def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
-    patches = []
-    for roi in roiset:
-        patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
-        patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
-
-        if pad_to:
-            patch = pad(patch, pad_to)
-
-        patches.append(patch)
-    return MonoPatchStack(patches)
-
-
-def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
-    patches_acc = get_patch_masks(roiset, pad_to=pad_to)
-
-    exported = []
-    for i, roi in enumerate(roiset):  # assumes index of patches_acc is same as dataframe
-        patch = patches_acc.iat_yxcz(i)
-        ext = 'png'
-        fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
-        write_patch_to_file(where, fname, patch)
-        exported.append(fname)
-    return exported
-
-
-def get_roiset_patches(
-        roiset,
-        rescale_clip: float = 0.0,
-        pad_to: int = 256,
-        make_3d: bool = False,
-        focus_metric: str = None,
-        rgb_overlay_channels: list = None,
-        rgb_overlay_weights: list = [1.0, 1.0, 1.0],
-        white_channel: int = None,
-        **kwargs
-) -> pd.DataFrame:
-
-    # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
-    raw = roiset.acc_raw
-    if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
-        assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
-        assert len(rgb_overlay_channels) == 3
-        assert len(rgb_overlay_weights) == 3
-
-        if white_channel:
-            assert white_channel < raw.chroma
-            stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
-        else:
-            stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)
-
-        for ii, ci in enumerate(rgb_overlay_channels):
-            if ci is None:
-                continue
-            assert isinstance(ci, int)
-            assert ci < raw.chroma
-            stack[:, :, ii, :] = _safe_add(
-                stack[:, :, ii, :], # either black or grayscale channel
-                rgb_overlay_weights[ii],
-                raw.data[:, :, ci, :]
-            )
-    else:
-        if white_channel:  # interpret as just a single channel
-            assert white_channel < raw.chroma
-            annotate_rgb = False
-            for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
-                ca = kwargs.get(k)
-                if ca is None:
-                    continue
-                assert(ca < raw.chroma)
-                if ca != white_channel:
-                    annotate_rgb = True
-                    break
-            if annotate_rgb:  # make RGB patches anyway to include annotation color
-                stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
-            else: # make monochrome patches
-                stack = raw.data[:, :, [white_channel], :]
-        else:
-            stack = raw.data
-
-    def _make_patch(roi):
-        patch3d = stack[roi.slice]
-        ph, pw, pc, pz = patch3d.shape
-        subpatch = patch3d[roi.relative_slice]
-
-        # make a 3d patch
-        if make_3d:
-            patch = patch3d
-
-        # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
-        elif focus_metric is not None:
-            foc = _focus_metrics()[focus_metric]
-
-            patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
-
-            for ci in range(0, pc):
-                me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
-                zif = np.argmax(me)
-                patch[:, :, ci, 0] = patch3d[:, :, ci, zif]
-
-        # make a 2d patch from middle of z-stack
-        else:
-            zim = floor(pz / 2)
-            patch = patch3d[:, :, :, [zim]]
-
-        assert len(patch.shape) == 4
-
-        if rescale_clip is not None:
-            patch = rescale(patch, rescale_clip)
-
-        if kwargs.get('draw_bounding_box') is True:
-            bci = kwargs.get('bounding_box_channel', 0)
-            assert bci < 3
-            if bci > 0:
-                patch = _make_rgb(patch)
-            for zi in range(0, patch.shape[3]):
-                patch[:, :, bci, zi] = draw_box_on_patch(
-                    patch[:, :, bci, zi],
-                    ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)),
-                    linewidth=kwargs.get('bounding_box_linewidth', 1)
-                )
-
-        if kwargs.get('draw_mask'):
-            mci = kwargs.get('mask_channel', 0)
-            mask = np.zeros(patch.shape[0:2], dtype=bool)
-            mask[roi.relative_slice[0:2]] = roi.mask
-            for zi in range(0, patch.shape[3]):
-                patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
-
-        if kwargs.get('draw_contour'):
-            mci = kwargs.get('contour_channel', 0)
-            mask = np.zeros(patch.shape[0:2], dtype=bool)
-            mask[roi.relative_slice[0:2]] = roi.mask
-
-            for zi in range(0, patch.shape[3]):
-                patch[:, :, mci, zi] = draw_contours_on_patch(
-                    patch[:, :, mci, zi],
-                    find_contours(mask)
-                )
-
-        if pad_to:
-            patch = pad(patch, pad_to)
-        return patch
-
-    dfe = roiset.get_df()
-    dfe['patch'] = roiset.get_df().apply(lambda r: _make_patch(r), axis=1)
-    return dfe
-
-
-def export_patches_from_zstack(
-        where: Path,
-        roiset,
-        prefix='patch',
-        **kwargs
-):
-    make_3d = kwargs.get('make_3d', False)
-    patches_df = get_roiset_patches(roiset, **kwargs)
-
-    def _export_patch(roi):
-        patch = InMemoryDataAccessor(roi.patch)
-        ext = 'tif' if make_3d or patch.chroma > 3 else 'png'
-        fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
-
-        if patch.dtype is np.dtype('uint16'):
-            write_patch_to_file(where, fname, resample_to_8bit(patch.data))
-        else:
-            write_patch_to_file(where, fname, patch)
-
-        exported.append({
-            'df_index': roi.Index,
-            'patch_filename': fname,
-            'location': where.__str__(),
-        })
-
-    exported = []
-    for roi in patches_df.itertuples():  # just used for label info
-        _export_patch(roi)
-
-    return exported
\ No newline at end of file
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/roiset.py
similarity index 53%
rename from model_server/extensions/chaeo/zmask.py
rename to model_server/extensions/chaeo/roiset.py
index 7bb1d99c..c478ff25 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/roiset.py
@@ -1,34 +1,68 @@
+from math import sqrt, floor
+from pathlib import Path
 from uuid import uuid4
 
 import numpy as np
 import pandas as pd
+from scipy.stats import moment
+from skimage.filters import sobel
 
-from skimage.measure import label, regionprops_table
+from skimage.measure import label, regionprops_table, shannon_entropy, find_contours
 from sklearn.preprocessing import PolynomialFeatures
 from sklearn.linear_model import LinearRegression
 
 from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.base.models import InstanceSegmentationModel
-
-from model_server.extensions.chaeo.annotators import draw_boxes_on_3d_image
-from model_server.extensions.chaeo.products import export_patches_from_zstack
-from extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams
-from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack
+from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb
+from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
+from model_server.extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams
+from model_server.extensions.chaeo.accessors import write_patch_to_file, MonoPatchStack, Multichannel3dPatchStack
 from model_server.extensions.chaeo.process import mask_largest_object
-from model_server.extensions.chaeo.products import get_roiset_patches, get_patch_masks, export_patch_masks
 
 
-def get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor:
+def _get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor:
     return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16'))
 
+
+def _focus_metrics():
+    return {
+        'max_intensity': lambda x: np.max(x),
+        'stdev': lambda x: np.std(x),
+        'max_sobel': lambda x: np.max(sobel(x)),
+        'rms_sobel': lambda x: sqrt(np.mean(sobel(x) ** 2)),
+        'entropy': lambda x: shannon_entropy(x),
+        'moment': lambda x: moment(x.flatten(), moment=2),
+    }
+
+
+def _safe_add(a, g, b):
+    assert a.dtype == b.dtype
+    assert a.shape == b.shape
+    assert g >= 0.0
+
+    return np.clip(
+        a.astype('uint32') + g * b.astype('uint32'),
+        0,
+        np.iinfo(a.dtype).max
+    ).astype(a.dtype)
+
+
 class RoiSet(object):
 
     def __init__(
             self,
-            acc_obj_ids: GenericImageDataAccessor,
             acc_raw: GenericImageDataAccessor,
+            acc_obj_ids: GenericImageDataAccessor,
             params: RoiSetMetaParams = RoiSetMetaParams(),
     ):
+        """
+        A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw.
+        RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods.
+        :param acc_raw: accessor to a generally a multichannel z-stack
+        :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity
+            labels its membership in a connected object
+        :param params: optional arguments that influence the definition and representation of ROIs
+        """
         assert acc_obj_ids.chroma == 1
         assert acc_obj_ids.nz == 1
         self.acc_obj_ids = acc_obj_ids
@@ -147,17 +181,11 @@ class RoiSet(object):
             projected = self.acc_raw.data.max(axis=-1)
         return projected
 
-    def get_patch_masks(self, **kwargs):
-        return get_patch_masks(self, **kwargs)
-
-    def export_patch_masks(self, where, **kwargs) -> list:
-        return export_patch_masks(self, where, **kwargs)
-
     def get_raw_patches(self, channel=None, pad_to=256, make_3d=False):  # padded, un-annotated 2d patches
         if channel:
-            patches_df = get_roiset_patches(self, white_channel=channel, pad_to=pad_to)
+            patches_df = self.get_patches(white_channel=channel, pad_to=pad_to)
         else:
-            patches_df = get_roiset_patches(self, pad_to=pad_to)
+            patches_df = self.get_patches(pad_to=pad_to)
         patches = list(patches_df['patch'])
         if channel is not None or self.acc_raw.chroma == 1:
             return MonoPatchStack(patches)
@@ -229,6 +257,179 @@ class RoiSet(object):
             om[self.acc_obj_ids.data == roi.label] = oc
         self.object_class_maps[name] = InMemoryDataAccessor(om)
 
+    def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
+        patches_acc = self.get_patch_masks(pad_to=pad_to)
+
+        exported = []
+        for i, roi in enumerate(self):  # assumes index of patches_acc is same as dataframe
+            patch = patches_acc.iat_yxcz(i)
+            ext = 'png'
+            fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
+            write_patch_to_file(where, fname, patch)
+            exported.append(fname)
+        return exported
+
+
+    def export_patches(self, where: Path, prefix='patch', **kwargs):
+        make_3d = kwargs.get('make_3d', False)
+        patches_df = self.get_patches(**kwargs)
+
+        def _export_patch(roi):
+            patch = InMemoryDataAccessor(roi.patch)
+            ext = 'tif' if make_3d or patch.chroma > 3 else 'png'
+            fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
+
+            if patch.dtype is np.dtype('uint16'):
+                write_patch_to_file(where, fname, resample_to_8bit(patch.data))
+            else:
+                write_patch_to_file(where, fname, patch)
+
+            exported.append({
+                'df_index': roi.Index,
+                'patch_filename': fname,
+                'location': where.__str__(),
+            })
+
+        exported = []
+        for roi in patches_df.itertuples():  # just used for label info
+            _export_patch(roi)
+
+        return exported
+
+    def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack:
+        patches = []
+        for roi in self:
+            patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
+            patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
+
+            if pad_to:
+                patch = pad(patch, pad_to)
+
+            patches.append(patch)
+        return MonoPatchStack(patches)
+
+    def get_patches(
+            self,
+            rescale_clip: float = 0.0,
+            pad_to: int = 256,
+            make_3d: bool = False,
+            focus_metric: str = None,
+            rgb_overlay_channels: list = None,
+            rgb_overlay_weights: list = [1.0, 1.0, 1.0],
+            white_channel: int = None,
+            **kwargs
+    ) -> pd.DataFrame:
+
+        # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
+        raw = self.acc_raw
+        if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
+            assert all([c < raw.chroma for c in rgb_overlay_channels if c is not None])
+            assert len(rgb_overlay_channels) == 3
+            assert len(rgb_overlay_weights) == 3
+
+            if white_channel:
+                assert white_channel < raw.chroma
+                stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
+            else:
+                stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)
+
+            for ii, ci in enumerate(rgb_overlay_channels):
+                if ci is None:
+                    continue
+                assert isinstance(ci, int)
+                assert ci < raw.chroma
+                stack[:, :, ii, :] = _safe_add(
+                    stack[:, :, ii, :],  # either black or grayscale channel
+                    rgb_overlay_weights[ii],
+                    raw.data[:, :, ci, :]
+                )
+        else:
+            if white_channel:  # interpret as just a single channel
+                assert white_channel < raw.chroma
+                annotate_rgb = False
+                for k in ['contour_channel', 'bounding_box_channel', 'mask_channel']:
+                    ca = kwargs.get(k)
+                    if ca is None:
+                        continue
+                    assert (ca < raw.chroma)
+                    if ca != white_channel:
+                        annotate_rgb = True
+                        break
+                if annotate_rgb:  # make RGB patches anyway to include annotation color
+                    stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
+                else:  # make monochrome patches
+                    stack = raw.data[:, :, [white_channel], :]
+            else:
+                stack = raw.data
+
+        def _make_patch(roi):
+            patch3d = stack[roi.slice]
+            ph, pw, pc, pz = patch3d.shape
+            subpatch = patch3d[roi.relative_slice]
+
+            # make a 3d patch
+            if make_3d:
+                patch = patch3d
+
+            # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
+            elif focus_metric is not None:
+                foc = _focus_metrics()[focus_metric]
+
+                patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype)
+
+                for ci in range(0, pc):
+                    me = [foc(subpatch[:, :, ci, zi]) for zi in range(0, pz)]
+                    zif = np.argmax(me)
+                    patch[:, :, ci, 0] = patch3d[:, :, ci, zif]
+
+            # make a 2d patch from middle of z-stack
+            else:
+                zim = floor(pz / 2)
+                patch = patch3d[:, :, :, [zim]]
+
+            assert len(patch.shape) == 4
+
+            if rescale_clip is not None:
+                patch = rescale(patch, rescale_clip)
+
+            if kwargs.get('draw_bounding_box') is True:
+                bci = kwargs.get('bounding_box_channel', 0)
+                assert bci < 3
+                if bci > 0:
+                    patch = make_rgb(patch)
+                for zi in range(0, patch.shape[3]):
+                    patch[:, :, bci, zi] = draw_box_on_patch(
+                        patch[:, :, bci, zi],
+                        ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)),
+                        linewidth=kwargs.get('bounding_box_linewidth', 1)
+                    )
+
+            if kwargs.get('draw_mask'):
+                mci = kwargs.get('mask_channel', 0)
+                mask = np.zeros(patch.shape[0:2], dtype=bool)
+                mask[roi.relative_slice[0:2]] = roi.mask
+                for zi in range(0, patch.shape[3]):
+                    patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
+
+            if kwargs.get('draw_contour'):
+                mci = kwargs.get('contour_channel', 0)
+                mask = np.zeros(patch.shape[0:2], dtype=bool)
+                mask[roi.relative_slice[0:2]] = roi.mask
+
+                for zi in range(0, patch.shape[3]):
+                    patch[:, :, mci, zi] = draw_contours_on_patch(
+                        patch[:, :, mci, zi],
+                        find_contours(mask)
+                    )
+
+            if pad_to:
+                patch = pad(patch, pad_to)
+            return patch
+
+        dfe = self._df
+        dfe['patch'] = self._df.apply(lambda r: _make_patch(r), axis=1)
+        return dfe
+
     def run_exports(self, where, channel, prefix, params: RoiSetExportParams):
         if not self.count:
             return
@@ -240,17 +441,17 @@ class RoiSet(object):
             if kp is None:
                 continue
             if k == 'patches_3d':
-                files = export_patches_from_zstack(
-                    subdir, self, white_channel=channel, prefix=pr, make_3d=True, **kp
+                files = self.export_patches(
+                    subdir, white_channel=channel, prefix=pr, make_3d=True, **kp
                 )
             if k == 'annotated_patches_2d':
-                files = export_patches_from_zstack(
-                    subdir, self, prefix=pr, make_3d=False, white_channel=channel,
+                files = self.export_patches(
+                    subdir, prefix=pr, make_3d=False, white_channel=channel,
                     bounding_box_channel=1, bounding_box_linewidth=2, **kp,
                 )
             if k == 'patches_2d':
-                files = export_patches_from_zstack(
-                    subdir, self, white_channel=channel, prefix=pr, make_3d=False, **kp
+                files = self.export_patches(
+                    subdir, white_channel=channel, prefix=pr, make_3d=False, **kp
                 )
                 df_patches = pd.DataFrame(files)
                 self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
@@ -313,3 +514,7 @@ def project_stack_from_focal_points(
         ),
         axis=3
     )
+
+
+
+
diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index 9631dd04..f3b82997 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -7,9 +7,8 @@ from model_server.conf.testing import output_path
 
 from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
 from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
-from model_server.extensions.chaeo.products import export_patches_from_zstack
 from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack
-from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet
+from model_server.extensions.chaeo.roiset import _get_label_ids, RoiSet
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.extensions.ilastik.models import IlastikPixelClassifierModel
 from model_server.base.models import DummyInstanceSegmentationModel
@@ -42,10 +41,10 @@ class BaseTestRoiSetMonoProducts(object):
 class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def _make_roi_set(self, mask_type='boxes', **kwargs):
-        id_map = get_label_ids(self.seg_mask)
+        id_map = _get_label_ids(self.seg_mask)
         roiset = RoiSet(
-            id_map,
             self.stack_ch_pa,
+            id_map,
             params=RoiSetMetaParams(
                 mask_type=mask_type,
                 filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
@@ -78,9 +77,9 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_roiset_from_non_zstacks(self, **kwargs):
         acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
         self.assertEqual(acc_zstack_slice.nz, 1)
-        id_map = get_label_ids(self.seg_mask)
+        id_map = _get_label_ids(self.seg_mask)
 
-        roiset = RoiSet(id_map, acc_zstack_slice, params=RoiSetMetaParams(mask_type='boxes'))
+        roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
         zmask = roiset.get_zmask()
 
         zmask_acc = InMemoryDataAccessor(zmask)
@@ -105,18 +104,16 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_make_2d_patches(self):
         roiset = self._make_roi_set()
-        files = export_patches_from_zstack(
+        files = roiset.export_patches(
             output_path / '2d_patches',
-            roiset,
             draw_bounding_box=True,
         )
         self.assertGreaterEqual(len(files), 1)
 
     def test_make_3d_patches(self):
         roiset = self._make_roi_set()
-        files = export_patches_from_zstack(
+        files = roiset.export_patches(
             output_path / '3d_patches',
-            roiset,
             make_3d=True)
         self.assertGreaterEqual(len(files), 1)
 
@@ -129,12 +126,12 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertEqual(result.shape, roiset.acc_raw.shape)
 
     def test_flatten_image(self):
-        id_map = get_label_ids(self.seg_mask)
+        id_map = _get_label_ids(self.seg_mask)
 
-        roiset = RoiSet(id_map, self.stack_ch_pa, params=RoiSetMetaParams(mask_type='boxes'))
+        roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
         df = roiset.get_df()
 
-        from model_server.extensions.chaeo.zmask import project_stack_from_focal_points
+        from model_server.extensions.chaeo.roiset import project_stack_from_focal_points
 
         img = project_stack_from_focal_points(
             df['centroid-0'].to_numpy(),
@@ -227,10 +224,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def setUp(self) -> None:
         super().setUp()
-        id_map = get_label_ids(self.seg_mask)
+        id_map = _get_label_ids(self.seg_mask)
         self.roiset = RoiSet(
-            id_map,
             self.stack,
+            id_map,
             params=RoiSetMetaParams(
                 expand_box_by=(128, 2),
                 mask_type='boxes',
@@ -239,9 +236,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         )
 
     def test_multichannel_to_mono_2d_patches(self):
-        files = export_patches_from_zstack(
+        files = self.roiset.export_patches(
             output_path / 'multichannel' / 'mono_2d_patches',
-            self.roiset,
             white_channel=3,
             draw_bounding_box=True,
         )
@@ -249,9 +245,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.chroma, 1)
 
     def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
-        files = export_patches_from_zstack(
+        files = self.roiset.export_patches(
             output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox',
-            self.roiset,
             white_channel=3,
             draw_bounding_box=True,
             bounding_box_channel=1,
@@ -260,9 +255,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.chroma, 3)
 
     def test_multichannnel_to_rgb_2d_patches_bbox(self):
-        files = export_patches_from_zstack(
+        files = self.roiset.export_patches(
             output_path / 'multichannel' / 'rgb_2d_patches_bbox',
-            self.roiset,
             white_channel=4,
             rgb_overlay_channels=(3, None, None),
             draw_mask=True,
@@ -273,9 +267,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.chroma, 3)
 
     def test_multichannnel_to_rgb_2d_patches_contour(self):
-        files = export_patches_from_zstack(
+        files = self.roiset.export_patches(
             output_path / 'multichannel' / 'rgb_2d_patches_contour',
-            self.roiset,
             rgb_overlay_channels=(3, None, None),
             draw_contour=True,
             contour_channel=1,
@@ -286,9 +279,8 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.get_one_channel_data(2).data.max(), 0)  # blue channel is black
 
     def test_multichannel_to_multichannel_tif_patches(self):
-        files = export_patches_from_zstack(
+        files = self.roiset.export_patches(
             output_path / 'multichannel' / 'multichannel_tif_patches',
-            self.roiset,
         )
         result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
         self.assertEqual(result.chroma, 5)
diff --git a/model_server/extensions/chaeo/workflows.py b/model_server/extensions/chaeo/workflows.py
index 7e90436b..08d06adf 100644
--- a/model_server/extensions/chaeo/workflows.py
+++ b/model_server/extensions/chaeo/workflows.py
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
 
 from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
 from model_server.extensions.chaeo.process import mask_largest_object
-from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet
+from model_server.extensions.chaeo.roiset import _get_label_ids, RoiSet
 
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
@@ -48,7 +48,7 @@ def infer_object_map_from_zstack(
     ti.click('classify_pixels')
 
     # make zmask
-    rois = RoiSet(get_label_ids(mip_mask), stack, params=roi_params)
+    rois = RoiSet(stack, _get_label_ids(mip_mask), params=roi_params)
     ti.click('generate_zmasks')
 
     rois.classify_by(
-- 
GitLab