diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 5e45bab4d6e9077c1799a957f041edcf60f1baed..e0be3ceae959cd459af7c5626f9181b0b172b957 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -157,7 +157,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): pxmap = self.infer(img) - mask = pxmap.get_mono(self.params['px_class']).apply(lambda x: x > self.params['px_prob_threshold']) + mask = pxmap.get_mono( + self.params['px_class'] + ).apply( + lambda x: x > self.params['px_prob_threshold'] + ) return mask diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index 4eedf981452b7771d9e1b6aa5c61f0fa72368633..9bf3648cfa5184e745fd03c3f13985324ef190d6 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -344,7 +344,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase): mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel( {'project_file': self.pa_ob_pxmap_classifier.__str__()} ) - obmap = mod.infer(img, pxmap)[0] + obmap = mod.infer(img, pxmap) self.assertEqual(obmap.hw, img.hw) self.assertEqual(obmap.nz, img.nz) diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py index 4adf24bb25a13e9a72bbad1bd0fbf9f0983b88ca..0dc2a9ed28cfcf6c40f61e26ddaf085cd0a37004 100644 --- a/tests/test_ilastik/test_roiset_workflow.py +++ b/tests/test_ilastik/test_roiset_workflow.py @@ -72,8 +72,12 @@ class BaseTestRoiSetMonoProducts(object): 'name': 'ilastik_px_mod', 'project_file': fp_px, 'model': ilm.IlastikPixelClassifierModel( - {'project_file': fp_px}, - ) + { + 'project_file': fp_px, + 'px_class': 0, + 'px_prob_threshold': 0.5, + }, + ), }, 'object_classifier': { 'name': 'ilastik_ob_mod',