diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index f3350ab7aca311479c1dba818a82318893ac661c..9b55d6d0eeec83f1bab962f18d81fd275eee5b0a 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -77,7 +77,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
-    def segment(self, img: GenericImageDataAccessor, pixel_class: int, pixel_probability_threshold=0.5):
+    def label_pixel_class(self, img: GenericImageDataAccessor, pixel_class: int = 0, pixel_probability_threshold=0.5):
         pxmap, _ = self.infer(img)
         mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
         return InMemoryDataAccessor(mask)
@@ -145,7 +145,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
-    def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
-        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_classes(img, mask, **kwargs)
+    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
+        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
         obmap, _ = self.infer(img, mask)
         return obmap
\ No newline at end of file
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index f4f400a60ebab7a485e48d8876d4ce428a3dd81a..ed74c43e8293a743876803315f3adfd964dfaa6b 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -35,7 +35,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
 
         with self.assertRaises(AttributeError):
-            pxmap, _ = model.infer(input_img)
+            mask = model.label_pixel_class(input_img)
 
 
     def test_run_pixel_classifier_on_random_data(self):
@@ -47,8 +47,8 @@ class TestIlastikPixelClassification(unittest.TestCase):
 
         input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
 
-        pxmap, _ = model.infer(input_img)
-        self.assertEqual(pxmap.shape, (h, w, 2, 1))
+        mask = model.label_pixel_class(input_img)
+        self.assertEqual(mask.shape, (h, w, 1, 1))
 
 
     def test_run_pixel_classifier(self):
@@ -66,20 +66,21 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.assertEqual(mono_image.shape_dict['C'], 1)
         self.assertEqual(mono_image.shape_dict['Z'], 1)
 
-        pxmap, _ = model.infer(mono_image)
+        mask = model.label_pixel_class(mono_image)
 
-        self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
-        self.assertEqual(pxmap.shape_dict['C'], 2)
-        self.assertEqual(pxmap.shape_dict['Z'], 1)
+        self.assertTrue(mask.is_mask())
+        self.assertEqual(mask.shape[0:2], cf.shape[0:2])
+        self.assertEqual(mask.shape_dict['C'], 1)
+        self.assertEqual(mask.shape_dict['Z'], 1)
         self.assertTrue(
             write_accessor_data_to_file(
                 output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
-                pxmap
+                mask
             )
         )
 
         self.mono_image = mono_image
-        self.pxmap = pxmap
+        self.mask = mask
 
     def test_run_object_classifier(self):
         self.test_run_pixel_classifier()
@@ -87,7 +88,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
             {'project_file': ilastik_classifiers['pxmap_to_obj']}
         )
-        objmap, _ = model.infer(self.mono_image, self.pxmap)
+        objmap, _ = model.infer(self.mono_image, self.mask)
 
         self.assertTrue(
             write_accessor_data_to_file(
diff --git a/model_server/models.py b/model_server/models.py
index 087a9975fb908498896c6a17d36ecd9af137ba74..4f07da5730a3d06d93f6a310bafead67f740a202 100644
--- a/model_server/models.py
+++ b/model_server/models.py
@@ -36,7 +36,12 @@ class Model(ABC):
         pass
 
     @abstractmethod
-    def infer(self, *args) -> (object, dict):  # return json describing inference result
+    def infer(self, *args) -> (object, dict):
+        """
+        Abstract method that carries out the computationally intensive step of running data through a model
+        :param args:
+        :return:
+        """
         pass
 
     def reload(self):
@@ -55,18 +60,30 @@ class ImageToImageModel(Model):
 
 class SemanticSegmentationModel(ImageToImageModel):
     """
-    Model that exposes a method that returns a binary mask for a given input image and pixel class
+    Base model that exposes a method that returns a binary mask for a given input image and pixel class
     """
 
     @abstractmethod
-    def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict):
+    def label_pixel_class(
+            self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
+        """
+        Given an image, return an image of the same shape where each pixel is assigned to one or more integer classes
+        """
         pass
 
 
 class InstanceSegmentationModel(ImageToImageModel):
+    """
+    Base model that exposes a method that returns an instance classification map for a given input image and mask
+    """
 
     @abstractmethod
-    def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
+    def label_instance_class(
+            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
+    ) -> GenericImageDataAccessor:
+        """
+        Given an image and a mask of the same size, return a map where each connected object is assigned a class
+        """
         if not mask.is_mask():
             raise InvalidInputImageError('Expecting a binary mask')
         if not img.shape == mask.shape: