diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index 21278461a2d6d949fd54d1abd453f2251f7e39de..8f22c3ad24310af9fb33e17b938a22705a3dbe4e 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -6,10 +6,10 @@ import vigra
 
 import extensions.ilastik.conf
 from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
-from model_server.models import ImageToImageModel, ParameterExpectedError
+from model_server.models import Model, ParameterExpectedError, SemanticSegmentationModel
 
 
-class IlastikImageToImageModel(ImageToImageModel):
+class IlastikModel(Model):
 
     def __init__(self, params, autoload=True):
         self.project_file = Path(params['project_file'])
@@ -27,7 +27,6 @@ class IlastikImageToImageModel(ImageToImageModel):
         self.shell = None
         super().__init__(autoload, params)
 
-
     def load(self):
         from ilastik import app
         from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo
@@ -51,7 +50,7 @@ class IlastikImageToImageModel(ImageToImageModel):
         return True
 
 
-class IlastikPixelClassifierModel(IlastikImageToImageModel):
+class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
     model_id = 'ilastik_pixel_classification'
     operations = ['segment', ]
 
@@ -78,43 +77,44 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
-    def segment(self, input_img, thresh, channel):
-        return InMemoryDataAccessor(
-            self.infer(input_img).data[:, :, channel, :] > thresh
+    def segment(self, img: GenericImageDataAccessor, pixel_class: int, pixel_probability_threshold=0.5):
+        pxmap, _ = self.infer(img)
+        mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
+        return InMemoryDataAccessor(mask)
+
+# TODO: deprecate
+class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
+    model_id = 'ilastik_object_classification_from_pixel_predictions'
+
+    @staticmethod
+    def get_workflow():
+        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
+        return ObjectClassificationWorkflowPrediction
+
+    def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
+        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
+
+        dsi = [
+            {
+                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
+                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
+            }
+        ]
+
+        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
+
+        assert (len(obmaps) == 1, 'ilastik generated more than one object map')
+
+        yxcz = np.moveaxis(
+            obmaps[0],
+            [1, 2, 3, 0],
+            [0, 1, 2, 3]
         )
+        return InMemoryDataAccessor(data=yxcz), {'success': True}
+
 
-# class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel):
-#     model_id = 'ilastik_object_classification_from_pixel_predictions'
-#
-#     @staticmethod
-#     def get_workflow():
-#         from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
-#         return ObjectClassificationWorkflowPrediction
-#
-#     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-#         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
-#         tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
-#
-#         dsi = [
-#             {
-#                 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
-#                 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
-#             }
-#         ]
-#
-#         obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
-#
-#         assert (len(obmaps) == 1, 'ilastik generated more than one object map')
-#
-#         yxcz = np.moveaxis(
-#             obmaps[0],
-#             [1, 2, 3, 0],
-#             [0, 1, 2, 3]
-#         )
-#         return InMemoryDataAccessor(data=yxcz), {'success': True}
-
-
-class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel):
+class IlastikObjectClassifierFromSegmentationModel(IlastikModel):
     model_id = 'ilastik_object_classification_from_segmentation'
 
     @staticmethod
diff --git a/extensions/ilastik/router.py b/extensions/ilastik/router.py
index 4d9a8b28a0bad08ff4d079e9175193c94d50e047..15a59b283ef6d7746127ba457401fd6d40042b7d 100644
--- a/extensions/ilastik/router.py
+++ b/extensions/ilastik/router.py
@@ -14,7 +14,7 @@ router = APIRouter(
 
 session = Session()
 
-def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
+def load_ilastik_model(model_class: ilm.IlastikModel, project_file: str, duplicate=True) -> dict:
     """
     Load an ilastik model of a given class and project filename.
     :param model_class:
@@ -35,10 +35,6 @@ def load_ilastik_model(model_class: ilm.IlastikImageToImageModel, project_file:
         )
     return {'model_id': result}
 
-# @router.put('/px/load/')
-# def load_px_model(project_file: str, duplicate: bool = True) -> dict:
-#     return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
-
 @router.put('/seg/load/')
 def load_px_model(project_file: str, duplicate: bool = True) -> dict:
     return load_ilastik_model(ilm.IlastikPixelClassifierModel, project_file, duplicate=duplicate)
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index 79cbe3b49630468c8c97558e22923cd522b26cf3..f4f400a60ebab7a485e48d8876d4ce428a3dd81a 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -113,7 +113,7 @@ class TestIlastikOverApi(TestServerBaseClass):
 
     def test_httpexception_if_incorrect_project_file_loaded(self):
         resp_load = requests.put(
-            self.uri + 'ilastik/px/load/',
+            self.uri + 'ilastik/seg/load/',
             params={'project_file': 'improper.ilp'},
         )
         self.assertEqual(resp_load.status_code, 404)
@@ -121,7 +121,7 @@ class TestIlastikOverApi(TestServerBaseClass):
 
     def test_load_ilastik_pixel_model(self):
         resp_load = requests.put(
-            self.uri + 'ilastik/px/load/',
+            self.uri + 'ilastik/seg/load/',
             params={'project_file': str(ilastik_classifiers['px'])},
         )
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
@@ -137,7 +137,7 @@ class TestIlastikOverApi(TestServerBaseClass):
         resp_list_1st = requests.get(self.uri + 'models').json()
         self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
         resp_load_2nd = requests.put(
-            self.uri + 'ilastik/px/load/',
+            self.uri + 'ilastik/seg/load/',
             params={
                 'project_file': str(ilastik_classifiers['px']),
                 'duplicate': True,
@@ -146,7 +146,7 @@ class TestIlastikOverApi(TestServerBaseClass):
         resp_list_2nd = requests.get(self.uri + 'models').json()
         self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
         resp_load_3rd = requests.put(
-            self.uri + 'ilastik/px/load/',
+            self.uri + 'ilastik/seg/load/',
             params={
                 'project_file': str(ilastik_classifiers['px']),
                 'duplicate': False,
@@ -172,14 +172,14 @@ class TestIlastikOverApi(TestServerBaseClass):
 
         # load models with these paths
         resp1 = requests.put(
-            self.uri + 'ilastik/px/load/',
+            self.uri + 'ilastik/seg/load/',
             params={
                 'project_file': ilp_win,
                 'duplicate': False,
             },
         )
         resp2 = requests.put(
-            self.uri + 'ilastik/px/load/',
+            self.uri + 'ilastik/seg/load/',
             params={
                 'project_file': ilp_posx,
                 'duplicate': False,
diff --git a/model_server/models.py b/model_server/models.py
index 4285b49c4d8a55bb3b77135d9e578546ea851d4c..07fc31dfdf2ea0b85ad5e63fa0daeb8bd15242fd 100644
--- a/model_server/models.py
+++ b/model_server/models.py
@@ -36,7 +36,7 @@ class Model(ABC):
         pass
 
     @abstractmethod
-    def infer(self, img: GenericImageDataAccessor) -> (object, dict): # return json describing inference result
+    def infer(self, *args) -> (object, dict):  # return json describing inference result
         pass
 
     def reload(self):
@@ -52,6 +52,15 @@ class ImageToImageModel(Model):
     def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
         pass
 
+class SemanticSegmentationModel(ImageToImageModel):
+    """
+    Model that exposes a method that returns a binary mask for a given input image and pixel class
+    """
+
+    @abstractmethod
+    def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict):
+        pass
+
 class DummyImageToImageModel(ImageToImageModel):
 
     model_id = 'dummy_make_white_square'
@@ -67,6 +76,7 @@ class DummyImageToImageModel(ImageToImageModel):
         result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
         return InMemoryDataAccessor(data=result), {'success': True}
 
+
 class Error(Exception):
     pass