From efce29dcbd08ef5b2d4b2c091dfc6dc2bac1ac6d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 25 Oct 2024 12:49:44 +0200
Subject: [PATCH] Tests pass

---
 model_server/extensions/ilastik/models.py  | 6 +++++-
 tests/test_ilastik/test_ilastik.py         | 2 +-
 tests/test_ilastik/test_roiset_workflow.py | 8 ++++++--
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 5e45bab4..e0be3cea 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -157,7 +157,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
 
     def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs):
         pxmap = self.infer(img)
-        mask = pxmap.get_mono(self.params['px_class']).apply(lambda x: x > self.params['px_prob_threshold'])
+        mask = pxmap.get_mono(
+            self.params['px_class']
+        ).apply(
+            lambda x: x > self.params['px_prob_threshold']
+        )
         return mask
 
 
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index 4eedf981..9bf3648c 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -344,7 +344,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
         mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
             {'project_file': self.pa_ob_pxmap_classifier.__str__()}
         )
-        obmap = mod.infer(img, pxmap)[0]
+        obmap = mod.infer(img, pxmap)
         self.assertEqual(obmap.hw, img.hw)
         self.assertEqual(obmap.nz, img.nz)
 
diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py
index 4adf24bb..0dc2a9ed 100644
--- a/tests/test_ilastik/test_roiset_workflow.py
+++ b/tests/test_ilastik/test_roiset_workflow.py
@@ -72,8 +72,12 @@ class BaseTestRoiSetMonoProducts(object):
                 'name': 'ilastik_px_mod',
                 'project_file': fp_px,
                 'model': ilm.IlastikPixelClassifierModel(
-                    {'project_file': fp_px},
-                )
+                    {
+                        'project_file': fp_px,
+                        'px_class': 0,
+                        'px_prob_threshold': 0.5,
+                    },
+                ),
             },
             'object_classifier': {
                 'name': 'ilastik_ob_mod',
-- 
GitLab