diff --git a/conf/testing.py b/conf/testing.py
index 33cc08595f00e9176e0207e793b7d65b8439b9df..9a08092ffaa2eafc1868fb545a398b972d86ba95 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -56,7 +56,8 @@ monozstackmask = {
 ilastik_classifiers = {
     'px': root / 'ilastik' / 'demo_px.ilp',
     'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
-    'seg_to_obj': root / 'ilastik' / 'new_auto_obj.ilp',
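+    # object classification project that takes a binary segmentation image as its second input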
+    'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
 }
 
 output_path = root / 'testing_output'
diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index 864413d814bb6bf8073287ac8be5cb93325729e7..f39a78f87363d290db226cfe9ee811e13af778c8 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -92,9 +92,16 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
         return ObjectClassificationWorkflowBinary
 
     def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        assert segmentation_img.is_mask()
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
-        tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
+        assert segmentation_img.is_mask()
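+        # convert boolean masks to uint8 (0/255) so vigra/ilastik receive numeric image data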
+        if segmentation_img.dtype == 'bool':
+            tagged_seg_data = vigra.taggedView(
+                255 * segmentation_img.data.astype('uint8'),
+                'yxcz'
+            )
+        else:
+            tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
 
         dsi = [
             {
@@ -105,7 +112,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
 
         obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
 
-        assert (len(obmaps) == 1, 'ilastik generated more than one object map')
+        assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
         yxcz = np.moveaxis(
             obmaps[0],
@@ -141,7 +148,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
 
         obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
 
-        assert (len(obmaps) == 1, 'ilastik generated more than one object map')
+        assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
         yxcz = np.moveaxis(
             obmaps[0],
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index ed74c43e8293a743876803315f3adfd964dfaa6b..1dfbde1685218234f152ace5968a1aae4a8a9fff 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -82,7 +82,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.mono_image = mono_image
         self.mask = mask
 
-    def test_run_object_classifier(self):
+    def test_run_object_classifier_from_pixel_predictions(self):
         self.test_run_pixel_classifier()
         fp = czifile['path']
         model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
@@ -98,6 +98,23 @@
         )
         self.assertEqual(objmap.data.max(), 3)
 
+    def test_run_object_classifier_from_segmentation(self):
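+        # reuse the pixel classifier test to populate self.mono_image and self.mask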
+        self.test_run_pixel_classifier()
+        fp = czifile['path']
+        model = ilm.IlastikObjectClassifierFromSegmentationModel(
+            {'project_file': ilastik_classifiers['seg_to_obj']}
+        )
+        objmap = model.label_instance_class(self.mono_image, self.mask)
+
+        self.assertTrue(
+            write_accessor_data_to_file(
+                output_path / f'obmap_from_seg_{fp.stem}.tif',
+                objmap,
+            )
+        )
+        self.assertEqual(objmap.data.max(), 3)
+
     def test_ilastik_pixel_classification_as_workflow(self):
         result = infer_image_to_image(
             czifile['path'],
diff --git a/model_server/workflows.py b/model_server/workflows.py
index c81d2afdc3e59e8153e1710524e286961ecd7746..6eb7b50d1b6e78518cc500ac5c1936cd266a333b 100644
--- a/model_server/workflows.py
+++ b/model_server/workflows.py
@@ -43,7 +43,7 @@ def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs)
     img = generate_file_accessor(fpi).get_one_channel_data(ch)
     ti.click('file_input')
 
-    outdata, _ = model.infer(img)
+    outdata = model.label_pixel_class(img)
     ti.click('inference')
 
     outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif')