From a6e1328761e918271a8b4fbfa521efb68a9ccfce Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Oct 2024 17:23:26 +0100
Subject: [PATCH] Ilastik pixel classifier calls batch mode for performance
 reasons; removed handling of
 ilastik.applets.featureSelection.opFeatureSelection.FeatureSelectionConstraintError;
 unsure if this remains a risk

---
 model_server/extensions/ilastik/models.py | 21 ++++++++++-----------
 tests/test_ilastik/test_ilastik.py        |  2 +-
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 2c895fef..35e5d8ec 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -148,17 +148,16 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         """
         Iterative over a patch stack, call inference separately on each cropped patch
         """
-        from ilastik.applets.featureSelection.opFeatureSelection import FeatureSelectionConstraintError
-
-        nc = len(self.labels)
-        data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float)  # interpret as PYXCZ
-        for i in range(0, img.count):
-            sl = img.get_slice_at(i)
-            try:
-                data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True)).data
-            except FeatureSelectionConstraintError: # occurs occasionally on small patches
-                continue
-        return PatchStack(data)
+        dsi = [
+            {
+                'Raw Data': self.PreloadedArrayDatasetInfo(
+                    preloaded_array=vigra.taggedView(patch, 'yxcz'))
+
+            } for patch in img.get_list()
+        ]
+        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)  # [z x h x w x n]
+        yxcz = [np.moveaxis(pm, [1, 2, 3, 0], [0, 1, 2, 3]) for pm in pxmaps]
+        return PatchStack(yxcz)
 
     def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs):
         pxmap = self.infer(img)
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index e6b36bb9..8068f66f 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -144,7 +144,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.assertEqual(mask.count, acc.count)
 
         pxmap = self.model.infer_patch_stack(acc)
-        self.assertEqual(pxmap.dtype, float)
+        self.assertEqual(pxmap.dtype, 'float32')
         self.assertEqual(pxmap.chroma, len(self.model.labels))
         self.assertEqual(pxmap.hw, acc.hw)
         self.assertEqual(pxmap.nz, acc.nz)
-- 
GitLab