diff --git a/model_server/base/models.py b/model_server/base/models.py
index 26e8af624eb1a98ae2067785b06a55b40dc45621..9aadc1c9a954eaf766c3bdfe6f9457db92c000ad 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -35,7 +35,7 @@ class Model(ABC):
         pass
 
     @abstractmethod
-    def infer(self, *args) -> (object, dict):
+    def infer(self, *args) -> object:
         """
         Abstract method that carries out the computationally intensive step of running data through a model
         :param args:
@@ -58,7 +58,7 @@ class ImageToImageModel(Model):
     """
 
     @abstractmethod
-    def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
+    def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
         pass
 
 
@@ -86,6 +86,25 @@ class SemanticSegmentationModel(ImageToImageModel):
         return PatchStack(data)
 
 
+class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
+    """
+    Trivial but functional model that labels all pixels above an intensity threshold as class 1
+    """
+
+    def __init__(self, params=None):
+        self.tr = params['tr']
+        self.loaded = self.load()
+
+    def infer(self, acc: GenericImageDataAccessor) -> GenericImageDataAccessor:
+        return acc.apply(lambda x: x > self.tr)
+
+    def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
+        return self.infer(acc, **kwargs)
+
+    def load(self):
+        return True
+
+
 class InstanceSegmentationModel(ImageToImageModel):
     """
     Base model that exposes a method that returns an instance classification map for a given input image and mask
@@ -127,24 +146,25 @@ class InstanceSegmentationModel(ImageToImageModel):
         return PatchStack(data)
 
 
-class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
+class PermissiveInstanceSegmentationModel(InstanceSegmentationModel):
     """
-    Trivial but functional model that labels all pixels above an intensity threshold as class 1
+    Trivial but functional model that labels all objects as class 1
     """
 
     def __init__(self, params=None):
-        self.tr = params['tr']
-        self.loaded = True
-
-    def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
-        return acc.apply(lambda x: x > self.tr)
-
-    def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
-        return self.infer(acc, **kwargs)
+        self.loaded = self.load()
 
     def load(self):
-        pass
+        return True
+
+    def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
+        return mask.apply(lambda x: (1 * (x > 0)).astype(acc.dtype))
 
+    def label_instance_class(
+            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
+    ) -> GenericImageDataAccessor:
+        super().label_instance_class(img, mask, **kwargs)
+        return self.infer(img, mask)
 
 class Error(Exception):
     pass
diff --git a/tests/base/test_model.py b/tests/base/test_model.py
index 8340111c5b45e764b76927c1534a64486d9cc2b6..62d516bbb7b239df5a61f4af2db65f7dbe779c34 100644
--- a/tests/base/test_model.py
+++ b/tests/base/test_model.py
@@ -1,9 +1,11 @@
 import unittest
 
+import numpy as np
+
 import model_server.conf.testing as conf
 from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel
 from model_server.base.accessors import CziImageFileAccessor
-from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel
+from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel
 
 czifile = conf.meta['image_files']['czifile']
 
@@ -64,3 +66,11 @@ class TestCziImageFileAccess(unittest.TestCase):
         img, mask = self.test_dummy_pixel_segmentation()
         model = DummyInstanceSegmentationModel()
         obmap = model.label_instance_class(img, mask)
+        self.assertTrue(all(obmap.unique()[0] == [0, 1]))
+        self.assertTrue(all(obmap.unique()[1] > 0))
+
+    def test_permissive_instance_segmentation(self):
+        img, mask = self.test_dummy_pixel_segmentation()
+        model = PermissiveInstanceSegmentationModel()
+        obmap = model.label_instance_class(img, mask)
+        self.assertTrue(np.all(mask.data == 255 * obmap.data))