diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index 9b55d6d0eeec83f1bab962f18d81fd275eee5b0a..864413d814bb6bf8073287ac8be5cb93325729e7 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -6,7 +6,7 @@ import vigra
 
 import extensions.ilastik.conf
 from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
-from model_server.models import Model, InstanceSegmentationModel, ParameterExpectedError, SemanticSegmentationModel
+from model_server.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
 
 
 class IlastikModel(Model):
@@ -82,23 +82,24 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
         return InMemoryDataAccessor(mask)
 
-# TODO: deprecate
-class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
-    model_id = 'ilastik_object_classification_from_pixel_predictions'
+
+class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
+    model_id = 'ilastik_object_classification_from_segmentation'
 
     @staticmethod
     def get_workflow():
-        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
-        return ObjectClassificationWorkflowPrediction
+        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
+        return ObjectClassificationWorkflowBinary
 
-    def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        assert segmentation_img.is_mask()
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
-        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
+        tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
 
         dsi = [
             {
                 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
-                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
+                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
             }
         ]
 
@@ -113,24 +114,28 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
+    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
+        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
+        obmap, _ = self.infer(img, mask)
+        return obmap
 
-class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
-    model_id = 'ilastik_object_classification_from_segmentation'
+
+class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
+    model_id = 'ilastik_object_classification_from_pixel_predictions'
 
     @staticmethod
     def get_workflow():
-        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
-        return ObjectClassificationWorkflowBinary
+        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
+        return ObjectClassificationWorkflowPrediction
 
-    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        assert segmentation_img.is_mask()
+    def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
-        tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
+        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
 
         dsi = [
             {
                 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
-                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
+                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
             }
         ]
 
@@ -145,7 +150,26 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
-    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
-        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
+
+    def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
+        """
+        Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
+        assigned a class.
+        :param img: input image
+        :param pxmap: map of pixel probabilities
+        :param kwargs:
+            pixel_classification_channel: channel of pxmap used to segment objects
+            pixel_classification_thresold: threshold of pxmap used to segment objects
+        :return:
+        """
+        if not img.shape == pxmap.shape:
+            raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
+        # TODO: check that pxmap is in-range
+        pxch = kwargs.get('pixel_classification_channel', 0)
+        pxtr = kwargs('pixel_classification_threshold', 0.5)
+        mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
+        # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
         obmap, _ = self.infer(img, mask)
-        return obmap
\ No newline at end of file
+        return obmap
+
+
diff --git a/model_server/accessors.py b/model_server/accessors.py
index 4f122fb13d8022255f5866daef4646dd3940a00e..04cb35646e0a0cde44dd7d8bc028fd80b7a5ad6f 100644
--- a/model_server/accessors.py
+++ b/model_server/accessors.py
@@ -34,6 +34,8 @@ class GenericImageDataAccessor(ABC):
     def is_3d(self):
         return True if self.shape_dict['Z'] > 1 else False
 
+    # TODO: implement is_probability
+
     def is_mask(self):
         return is_mask(self._data)