From 9331b28a3396a0bb900d7fbf875cfe4c1ccdbaa7 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 5 Aug 2024 11:43:35 +0200
Subject: [PATCH] Corrected Gaussian smoothing in pixel map

---
 model_server/extensions/ilastik/models.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 5e7a0a75..fe2a7052 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -107,7 +107,7 @@ class IlastikPixelClassifierParams(IlastikParams):
 class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
     model_id = 'ilastik_pixel_classification'
     operations = ['segment', ]
-    
+
     def __init__(self, params: IlastikPixelClassifierParams, **kwargs):
         super(IlastikPixelClassifierModel, self).__init__(params, **kwargs)
 
@@ -169,7 +169,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         pxmap, _ = self.infer(img)
         sig = self.params['px_smoothing']
         if sig > 0.0:
-            proc = smooth(img.data, sig)
+            proc = smooth(pxmap.data, sig)
         else:
             proc = pxmap.data
         mask = proc[:, :, self.params['px_class'], :] > self.params['px_prob_threshold']
@@ -317,9 +317,11 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         """
         if not img.shape == pxmap.shape:
             raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
+        if not pxmap.data.min() >= 0.0 and pxmap.data.max() <= 1.0:
+            raise InvalidInputImageError('Pixel probability values must be between 0.0 and 1.0')
         pxch = kwargs.get('pixel_classification_channel', 0)
-        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
-        mask = pxmap.get_mono(pxch).apply(lambda x: x > pxtr)
+        pxtr = kwargs('pixel_classification_threshold', 0.5)
+        mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
         obmap, _ = self.infer(img, mask)
         return obmap
 
-- 
GitLab