From 2a542f58588b437664720eb2c65946b69820e858 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 26 Apr 2024 15:16:33 +0200
Subject: [PATCH] ilastik now allows multichannel inputs

---
 model_server/extensions/ilastik/models.py    | 24 ++++++++++++++++----
 model_server/extensions/ilastik/workflows.py | 10 ++++++--
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 2dea1332..2c92c026 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -104,8 +104,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         return [l.decode() for l in h5['PixelClassification/LabelNames'][()]]
 
     def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
-        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
-            raise IlastikInputShapeError()
+        if self.model_chroma != input_img.chroma:
+            raise IlastikInputShapeError(
+                f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
+            )
+        if self.model_3d != input_img.is_3d():
+            if self.model_3d:
+                raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
+            else:
+                raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
 
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
         dsi = [
@@ -164,7 +171,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
     def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
         if self.model_chroma != input_img.chroma:
             raise IlastikInputShapeError(
-                f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}'
+                f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
             )
         if self.model_3d != input_img.is_3d():
             if self.model_3d:
@@ -229,8 +236,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         return ObjectClassificationWorkflowPrediction
 
     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
-            raise IlastikInputShapeError()
+        if self.model_chroma != input_img.chroma:
+            raise IlastikInputShapeError(
+                f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
+            )
+        if self.model_3d != input_img.is_3d():
+            if self.model_3d:
+                raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
+            else:
+                raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
 
         if isinstance(input_img, PatchStack):
             assert isinstance(pxmap_img, PatchStack)
diff --git a/model_server/extensions/ilastik/workflows.py b/model_server/extensions/ilastik/workflows.py
index c5b1575f..6f913f65 100644
--- a/model_server/extensions/ilastik/workflows.py
+++ b/model_server/extensions/ilastik/workflows.py
@@ -26,6 +26,7 @@ def infer_px_then_ob_model(
         px_model: IlastikPixelClassifierModel,
         ob_model: IlastikObjectClassifierFromPixelPredictionsModel,
         where_output: Path,
+        channel: int = None,
         **kwargs
 ) -> WorkflowRunRecord:
     """
@@ -35,6 +36,7 @@ def infer_px_then_ob_model(
     :param px_model: model instance for pixel classification
     :param ob_model: model instance for object classification
     :param where_output: Path object that references output image directory
+    :param channel: input image channel to pass to pixel classification, or all channels if None
     :param kwargs: variable-length keyword arguments
     :return:
     """
@@ -42,8 +44,12 @@ def infer_px_then_ob_model(
     assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel)
 
     ti = Timer()
-    ch = kwargs.get('channel')
-    img = generate_file_accessor(fpi).get_one_channel_data(ch, mip=kwargs.get('mip', False))
+    raw_acc = generate_file_accessor(fpi)
+    if channel is not None:
+        channels = [channel]
+    else:
+        channels = range(0, raw_acc.chroma)
+    img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False))
     ti.click('file_input')
 
     px_map, _ = px_model.infer(img)
-- 
GitLab