diff --git a/model_server/base/api.py b/model_server/base/api.py
index ba047efe7a1d2a0e487fe8a6582fcf58be9ac670..f1f658135e7729822f86f11a99e50392f9b35242 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -69,16 +69,18 @@ def list_session_log() -> list:
 def list_active_models():
     return session.describe_loaded_models()
 
+
 class BinaryThresholdSegmentationParams(BaseModel):
-    channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
     tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
 
+
 @app.put('/models/seg/binary_threshold/load/')
 def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
     result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
     session.log_info(f'Loaded binary threshold segmentation model {result}')
     return {'model_id': result}
 
+
 @app.get('/accessors')
 def list_accessors():
     return session.list_accessors()
diff --git a/model_server/base/models.py b/model_server/base/models.py
index 346bfaa4d2a1df6c438b0d2689bbf17c09c8c570..26e8af624eb1a98ae2067785b06a55b40dc45621 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -134,18 +134,13 @@ class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
 
     def __init__(self, params=None):
         self.tr = params['tr']
-        self.channel = params['channel']
         self.loaded = True
 
-    def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
-        if self.channel:
-            acc = img.get_mono(self.channel)
-        else:
-            acc = img
-        return acc.get_mono(self.channel).apply(lambda x: x > self.tr)
+    def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
+        return acc.apply(lambda x: x > self.tr)
 
-    def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
-        return self.infer(img, **kwargs)
+    def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
+        return self.infer(acc, **kwargs)
 
     def load(self):
         pass
diff --git a/tests/base/test_model.py b/tests/base/test_model.py
index 38e51c0b12650ffdbf70a86f0e7b6b5f54102fa3..8340111c5b45e764b76927c1534a64486d9cc2b6 100644
--- a/tests/base/test_model.py
+++ b/tests/base/test_model.py
@@ -56,7 +56,7 @@ class TestCziImageFileAccess(unittest.TestCase):
         return img, mask
 
     def test_binary_segmentation(self):
-        model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0})
+        model = BinaryThresholdSegmentationModel({'tr': 3e4})
         res = model.label_pixel_class(self.cf)
         self.assertTrue(res.is_mask())