From a6e1328761e918271a8b4fbfa521efb68a9ccfce Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 30 Oct 2024 17:23:26 +0100 Subject: [PATCH] Ilastik pixel classifier calls batch mode for performance reasons; removed handling of ilastik.applets.featureSelection.opFeatureSelection.FeatureSelectionConstraintError; unsure if this remains a risk --- model_server/extensions/ilastik/models.py | 21 ++++++++++----------- tests/test_ilastik/test_ilastik.py | 2 +- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 2c895fef..35e5d8ec 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -148,17 +148,16 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): """ Iterative over a patch stack, call inference separately on each cropped patch """ - from ilastik.applets.featureSelection.opFeatureSelection import FeatureSelectionConstraintError - - nc = len(self.labels) - data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float) # interpret as PYXCZ - for i in range(0, img.count): - sl = img.get_slice_at(i) - try: - data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True)).data - except FeatureSelectionConstraintError: # occurs occasionally on small patches - continue - return PatchStack(data) + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo( + preloaded_array=vigra.taggedView(patch, 'yxcz')) + + } for patch in img.get_list() + ] + pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + yxcz = [np.moveaxis(pm, [1, 2, 3, 0], [0, 1, 2, 3]) for pm in pxmaps] + return PatchStack(yxcz) def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs): pxmap = self.infer(img) diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index e6b36bb9..8068f66f 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -144,7 +144,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(mask.count, acc.count) pxmap = self.model.infer_patch_stack(acc) - self.assertEqual(pxmap.dtype, float) + self.assertEqual(pxmap.dtype, 'float32') self.assertEqual(pxmap.chroma, len(self.model.labels)) self.assertEqual(pxmap.hw, acc.hw) self.assertEqual(pxmap.nz, acc.nz) -- GitLab