diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py
index 97a13aaf75192c76a75ace8c8bcccf3678370a5d..5fdea52941b420f18849aef78454a3b9a200da45 100644
--- a/model_server/conf/testing.py
+++ b/model_server/conf/testing.py
@@ -57,6 +57,9 @@ ilastik_classifiers = {
     'px': root / 'ilastik' / 'demo_px.ilp',
     'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
     'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
+    'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp',
+    'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp',
+    'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp',
 }
 
 roiset_test_data = {
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 2dea13329513d0256baf7e9f2c698d4731301893..2c92c026c08d78052b21b982faa7aa4b4b7e1ec0 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -104,8 +104,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         return [l.decode() for l in h5['PixelClassification/LabelNames'][()]]
 
     def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
-        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
-            raise IlastikInputShapeError()
+        if self.model_chroma != input_img.chroma:
+            raise IlastikInputShapeError(
+                f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
+            )
+        if self.model_3d != input_img.is_3d():
+            if self.model_3d:
+                raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
+            else:
+                raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
 
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
         dsi = [
@@ -164,7 +171,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
     def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
         if self.model_chroma != input_img.chroma:
             raise IlastikInputShapeError(
-                f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}'
+                f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
             )
         if self.model_3d != input_img.is_3d():
             if self.model_3d:
@@ -229,8 +236,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         return ObjectClassificationWorkflowPrediction
 
     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
-            raise IlastikInputShapeError()
+        if self.model_chroma != input_img.chroma:
+            raise IlastikInputShapeError(
+                f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
+            )
+        if self.model_3d != input_img.is_3d():
+            if self.model_3d:
+                raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
+            else:
+                raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
 
         if isinstance(input_img, PatchStack):
             assert isinstance(pxmap_img, PatchStack)
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index a1a4983efbded47e54804521c04394c55795439c..18e95a4df44279b452bc11226700375b63c15381 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -7,6 +7,7 @@ import numpy as np
 from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data
 from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
 from model_server.extensions.ilastik import models as ilm
+from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
 from model_server.base.models import InvalidObjectLabelsError
 from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
 from model_server.base.workflows import classify_pixels
@@ -320,6 +321,92 @@ class TestIlastikOverApi(TestServerBaseClass):
         self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
 
 
+class TestIlastikOnMultichannelInputs(TestServerBaseClass):
+    def setUp(self) -> None:
+        super(TestIlastikOnMultichannelInputs, self).setUp()
+        self.pa_px_classifier = ilastik_classifiers['px_color_zstack']
+        self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']
+        self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']
+        self.pa_input_image = roiset_test_data['multichannel_zstack']['path']
+        self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d']
+
+    def _copy_input_file_to_server(self):
+        from shutil import copyfile
+
+        pa_data = roiset_test_data['multichannel_zstack']['path']
+        resp = self._get('paths')
+        pa = resp.json()['inbound_images']
+        outpath = pathlib.Path(pa) / pa_data.name
+        copyfile(
+            czifile['path'],
+            outpath
+        )
+
+    def test_classify_pixels(self):
+        img = generate_file_accessor(self.pa_input_image)
+        self.assertGreater(img.chroma, 1)
+        mod = ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier})
+        pxmap = mod.infer(img)[0]
+        self.assertEqual(pxmap.hw, img.hw)
+        self.assertEqual(pxmap.nz, img.nz)
+        return pxmap
+
+    def test_classify_objects(self):
+        pxmap = self.test_classify_pixels()
+        img = generate_file_accessor(self.pa_input_image)
+        mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier})
+        obmap = mod.infer(img, pxmap)[0]
+        self.assertEqual(obmap.hw, img.hw)
+        self.assertEqual(obmap.nz, img.nz)
+
+    def _call_workflow(self, channel):
+        return infer_px_then_ob_model(
+            self.pa_input_image,
+            ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}),
+            ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}),
+            output_path,
+            channel=channel,
+        )
+
+    def test_workflow(self):
+        with self.assertRaises(ilm.IlastikInputShapeError):
+            self._call_workflow(channel=0)
+        res = self._call_workflow(channel=None)
+        acc_input = generate_file_accessor(self.pa_input_image)
+        acc_obmap = generate_file_accessor(res.object_map_filepath)
+        self.assertEqual(acc_obmap.hw, acc_input.hw)
+        self.assertEqual(len(acc_obmap._unique()[1]), 3)
+
+
+    def test_api(self):
+        resp_load = self._put(
+            'ilastik/seg/load/',
+            query={'project_file': str(self.pa_px_classifier)},
+        )
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        px_model_id = resp_load.json()['model_id']
+
+        resp_load = self._put(
+            'ilastik/pxmap_to_obj/load/',
+            query={'project_file': str(self.pa_ob_pxmap_classifier)},
+        )
+        self.assertEqual(resp_load.status_code, 200, resp_load.json())
+        ob_model_id = resp_load.json()['model_id']
+
+        resp_infer = self._put(
+            'ilastik/pixel_then_object_classification/infer/',
+            query={
+                'px_model_id': px_model_id,
+                'ob_model_id': ob_model_id,
+                'input_filename': self.pa_input_image,
+            }
+        )
+        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
+        acc_input = generate_file_accessor(self.pa_input_image)
+        acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath'])
+        self.assertEqual(acc_obmap.hw, acc_input.hw)
+
+
 class TestIlastikObjectClassification(unittest.TestCase):
     def setUp(self):
         stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
diff --git a/model_server/extensions/ilastik/workflows.py b/model_server/extensions/ilastik/workflows.py
index c5b1575f00c1e27f1fc977ebd709dff48a16681d..6f913f652d013145fdcc2a9f91d3ed0db4234eba 100644
--- a/model_server/extensions/ilastik/workflows.py
+++ b/model_server/extensions/ilastik/workflows.py
@@ -26,6 +26,7 @@ def infer_px_then_ob_model(
         px_model: IlastikPixelClassifierModel,
         ob_model: IlastikObjectClassifierFromPixelPredictionsModel,
         where_output: Path,
+        channel: int = None,
         **kwargs
 ) -> WorkflowRunRecord:
     """
@@ -35,6 +36,7 @@ def infer_px_then_ob_model(
     :param px_model: model instance for pixel classification
     :param ob_model: model instance for object classification
     :param where_output: Path object that references output image directory
+    :param channel: input image channel to pass to pixel classification, or all channels if None
     :param kwargs: variable-length keyword arguments
     :return:
     """
@@ -42,8 +44,12 @@ def infer_px_then_ob_model(
     assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel)
 
     ti = Timer()
-    ch = kwargs.get('channel')
-    img = generate_file_accessor(fpi).get_one_channel_data(ch, mip=kwargs.get('mip', False))
+    raw_acc = generate_file_accessor(fpi)
+    if channel is not None:
+        channels = [channel]
+    else:
+        channels = range(0, raw_acc.chroma)
+    img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False))
     ti.click('file_input')
 
     px_map, _ = px_model.infer(img)