From 98be6bb7dbe8e1a1b6b5e2a41e4e9812bdc60d59 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 25 Mar 2024 13:51:25 +0100
Subject: [PATCH] Improve normalization when generating instance segmentation
 model

---
 model_server/extensions/ilastik/models.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index d3cc0d93..7229b43e 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -252,7 +252,12 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
             def label_instance_class(
                     self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
             ) -> GenericImageDataAccessor:
-                return super().label_instance_class(img, mask, pixel_classification_channel=px_ch)
+                if mask.dtype == 'bool':
+                    norm_mask = 1.0 * mask.data
+                else:
+                    norm_mask = mask.data / np.iinfo(mask.dtype).max
+                norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32'))
+                return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch)
         return _Mod(params={'project_file': self.project_file})
 
 
-- 
GitLab