diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index 45256258d66cd628590615fd804c5b3397c99272..541d0404d241e61f2be82d5c523bfb58c8847b63 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -209,6 +209,19 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             DummyInstanceSegmentationModel(),
         ]
 
+        models = {
+            'pixel_classifier': {
+                'model': self.pxmodel,
+                'params': {
+                    'px_class': 0,
+                    'px_prob_threshold': 0.6,
+                }
+            },
+            'object_classifier': {
+                'model': DummyInstanceSegmentationModel(),
+            }
+        }
+
         roi_params = RoiSetMetaParams(**{
             'mask_type': 'boxes',
             'filters': {
@@ -240,8 +253,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             multichannel_zstack['path'],
             output_path / 'roiset' / 'workflow',
             models,
-            pixel_class=pp['pxmap_channel'],
-            pixel_probability_threshold=pp['pxmap_threshold'],
             segmentation_channel=pp['segmentation_channel'],
             patches_channel=pp['patches_channel'],
             export_params=export_params,
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 7a0e09413fa30a2648ff8941589a9e5f1126880a..619f7c95be2868327d2654a327a913cef7dba81f 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -26,17 +26,12 @@ def infer_object_map_from_zstack(
         segmentation_channel: int,
         patches_channel: int,
         zmask_zindex: int = None,  # None for MIP,
-        zmask_clip: int = None,
         roi_params: RoiSetMetaParams = RoiSetMetaParams(),
         export_params: RoiSetExportParams = RoiSetExportParams(),
-        pixel_class=0,
-        pixel_probability_threshold=0.6,
 ) -> Dict:
     assert len(models) == 2
-    pixel_classifier = models[0]
-    assert isinstance(pixel_classifier, SemanticSegmentationModel)
-    object_classifier = models[1]
-    assert isinstance(object_classifier, InstanceSegmentationModel)
+    assert isinstance(models['pixel_classifier']['model'], SemanticSegmentationModel)
+    assert isinstance(models['object_classifier']['model'], InstanceSegmentationModel)
 
     ti = Timer()
     stack = generate_file_accessor(Path(input_file_path))
@@ -49,20 +44,16 @@ def infer_object_map_from_zstack(
         zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex]
     else:
         zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
-    if zmask_clip:
-        zmask_data = rescale(zmask_data, zmask_clip)
-    mip = InMemoryDataAccessor(
-        zmask_data,
-    )
+    mip = InMemoryDataAccessor(zmask_data)
 
-    mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,)
+    mip_mask = models['pixel_classifier']['model'].label_pixel_class(mip, **models['pixel_classifier']['params'])
     ti.click('classify_pixels')
 
     # make zmask
     rois = RoiSet(mip_mask, stack, params=roi_params)
     ti.click('generate_zmasks')
 
-    rois.classify_by(patches_channel, object_classifier)
+    rois.classify_by(patches_channel, models['object_classifier']['model'])
     ti.click('classify_objects')
 
     rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index f39a78f87363d290db226cfe9ee811e13af778c8..ac8d253b6ac44c1aaddb98b8350b180eef49b437 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -77,9 +77,9 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
-    def label_pixel_class(self, img: GenericImageDataAccessor, pixel_class: int = 0, pixel_probability_threshold=0.5):
+    def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
         pxmap, _ = self.infer(img)
-        mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
+        mask = pxmap.data[:, :, px_class, :] > px_prob_threshold
         return InMemoryDataAccessor(mask)