diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py
index 796469e9d7cd52b9af9c3d852b0af1dd1f10a0bd..fccc4f0bae2943ed0ab63f2a1a298d15dda71e0f 100644
--- a/extensions/chaeo/params.py
+++ b/extensions/chaeo/params.py
@@ -11,6 +11,7 @@ class PatchParams(BaseModel):
     focus_metric: str = 'max_sobel'
     rgb_overlay_channels: List[Union[int, None]] = [None, None, None]
     rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0]
+    pad_to: int = 256
 
 
 class AnnotatedZStackParams(BaseModel):
@@ -18,6 +19,7 @@ class AnnotatedZStackParams(BaseModel):
 
 
 class RoiSetExportParams(BaseModel):
+    expand_box_by: List[int] = [128, 0]
     pixel_probabilities: bool = False
     patches_3d: Union[PatchParams, None] = None
     patches_2d_for_annotation: Union[PatchParams, None] = None
diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index 357a0ca6feba201097a27782c4ca4d4717f3729f..a78fbea4e9b4721fb19d10c60c86688cbdc15edb 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -211,18 +211,27 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         export_params = RoiSetExportParams(**{
             'pixel_probabilities': True,
             'patches_3d': {},
-            # 'patches_2d_for_annotation': {
-            #     'draw_bounding_box': True,
-            #     'rgb_overlay_channels': [4, 2, None]
-            # }
+            'patches_2d_for_annotation': {
+                'draw_bounding_box': True,
+                'rgb_overlay_channels': [3, None, None],
+                'rgb_overlay_weights': [0.2, 1.0, 1.0],
+                'pad_to': 512,
+            },
+            'patches_2d_for_training': {
+                'draw_bounding_box': False,
+                'draw_mask': False,
+            },
+            'patch_masks': True,
+            'annotated_z_stack': {}
         })
         infer_object_map_from_zstack(
             multichannel_zstack['path'],
-            output_path / 'workflow',
+            output_path / 'roiset' / 'workflow',
             models,
             pxmap_foreground_channel=pp['pxmap_channel'],
             pxmap_threshold=pp['pxmap_threshold'],
             segmentation_channel=pp['segmentation_channel'],
             patches_channel=pp['patches_channel'],
-            # exports=export_params,
-        )
\ No newline at end of file
+            exports=export_params,
+        )
+
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 46e693b159ea33d2b6b431ef66b3554a9d71ad77..91b09c9d45178659a740efa85551b9caa6486776 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -286,7 +286,7 @@ def infer_object_map_from_zstack(
         stack,
         mask_type=zmask_type,
         filters=zmask_filters,
-        expand_box_by=kwargs.get('zmask_expand_box_by', (0, 0)),
+        expand_box_by=exports.expand_box_by,
     )
     ti.click('generate_zmasks')
 
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 886dcf1c7ab5a713d5064706441f5ac2707b0be8..31bf51702d539affc825cea8e0e437ca28b00e8d 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -53,36 +53,29 @@ class RoiSet(object):
             make_3d=True,
         )
 
-    def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams):
         if not self.count:
             return
         files = export_multichannel_patches_from_zstack(
             where,
-            self.acc_raw.get_one_channel_data(channel),
+            self.acc_raw,
             self.zmask_meta,
+            ch_white=channel,
             prefix=prefix,
-            draw_bounding_box=params.draw_bounding_box,
             rescale_clip=params.rescale_clip,
             make_3d=False,
             focus_metric=params.focus_metric,
-            ch_white=channel,
-            ch_rgb_overlay=params.rgb_overlay_channels,
-            bounding_box_channel=1,
-            bounding_box_linewidth=2,
-            draw_contour=params.draw_contour,
-            draw_mask=params.draw_mask,
-            overlay_gain=params.rgb_overlay_weights,
         )
         df_patches = pd.DataFrame(files)
         self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
         self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
 
-    def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams):
         if not self.count:
             return
         files = export_multichannel_patches_from_zstack(
             where,
-            self.acc_raw.get_one_channel_data(channel),
+            self.acc_raw,
             self.zmask_meta,
             prefix=prefix,
             draw_bounding_box=params.draw_bounding_box,
@@ -96,22 +89,7 @@ class RoiSet(object):
             draw_contour=params.draw_contour,
             draw_mask=params.draw_mask,
             overlay_gain=params.rgb_overlay_weights,
-        )
-        df_patches = pd.DataFrame(files)
-        self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
-        self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
-
-    def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams):
-        if not self.count:
-            return
-        files = export_multichannel_patches_from_zstack(
-            where,
-            self.acc_raw.get_one_channel_data(channel),
-            self.zmask_meta,
-            prefix=prefix,
-            rescale_clip=params.rescale_clip,
-            make_3d=False,
-            focus_metric=params.focus_metric,
+            pad_to=params.pad_to,
         )
 
     def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams):
@@ -152,7 +130,10 @@ class RoiSet(object):
         return projected
 
     def get_raw_patches(self, channel):
-        return get_patches_from_zmask_meta(self.acc_raw, self.zmask_meta).get_one_channel_data(channel)
+        return get_patches_from_zmask_meta(
+            self.acc_raw.get_one_channel_data(channel),
+            self.zmask_meta
+        )
 
     def get_patch_masks(self):
         return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta)