From 0a3039faad8e2427f9da0f3c8d65776b96b2f405 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 19 Feb 2024 13:05:19 +0100
Subject: [PATCH] Move patch stack wrapper to ilastik extension

---
 .../chaeo/examples/batch_obj_cla.py           |  2 +-
 ...fer_labels_to_ilastik_object_classifier.py |  3 +-
 model_server/extensions/chaeo/models.py       | 40 +------------------
 model_server/extensions/ilastik/models.py     | 36 +++++++++++++++++
 4 files changed, 40 insertions(+), 41 deletions(-)

diff --git a/model_server/extensions/chaeo/examples/batch_obj_cla.py b/model_server/extensions/chaeo/examples/batch_obj_cla.py
index 7471034f..eef5e703 100644
--- a/model_server/extensions/chaeo/examples/batch_obj_cla.py
+++ b/model_server/extensions/chaeo/examples/batch_obj_cla.py
@@ -2,7 +2,7 @@ from pathlib import Path
 
 from model_server.conf.testing import output_path
 from model_server.base.util import autonumber_new_directory, get_matching_files, loop_workflow
-from model_server.extensions.chaeo.models import PatchStackObjectClassifier
+from extensions.ilastik.models import PatchStackObjectClassifier
 from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack
 from model_server.extensions.ilastik.models import IlastikPixelClassifierModel
 
diff --git a/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py b/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
index ae306745..5d9625b5 100644
--- a/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
+++ b/model_server/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
@@ -4,7 +4,8 @@ import pandas as pd
 import skimage
 
 from model_server.extensions.chaeo.accessors import MonoPatchStackFromFile
-from model_server.extensions.chaeo.models import generate_ilastik_object_classifier, PatchStackObjectClassifier
+from model_server.extensions.chaeo.models import generate_ilastik_object_classifier
+from extensions.ilastik.models import PatchStackObjectClassifier
 from model_server.base.accessors import GenericImageDataAccessor, write_accessor_data_to_file
 
 
diff --git a/model_server/extensions/chaeo/models.py b/model_server/extensions/chaeo/models.py
index 7f8c08bd..1a969db3 100644
--- a/model_server/extensions/chaeo/models.py
+++ b/model_server/extensions/chaeo/models.py
@@ -4,48 +4,10 @@ import shutil
 import h5py
 import numpy as np
 import skimage
-import vigra
 
-from model_server.extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
-from model_server.extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
+from model_server.extensions.chaeo.accessors import MonoPatchStackFromFile
 
 
-class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
-    """
-    Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented
-    as time-series images where each frame contains only one object.
-    """
-
-    def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> (np.ndarray, dict):
-        assert segmentation_acc.is_mask()
-        assert input_acc.chroma == 1
-
-        tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
-        tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')
-
-        dsi = [
-            {
-                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
-                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
-            }
-        ]
-
-        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)  # [z x h x w x n]
-
-        assert len(obmaps) == 1, 'ilastik generated more than one object map'
-
-        # for some reason ilastik scrambles these axes to Z(1)YX(1)
-        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
-        yxz = np.moveaxis(
-            obmaps[0][:, 0, :, :, 0],
-            [1, 2, 0],
-            [0, 1, 2]
-        )
-
-        assert yxz.shape[0:2] == input_acc.hw
-        assert yxz.shape[2] == input_acc.nz
-        return MonoPatchStack(data=yxz), {'success': True}
-
 
 def generate_ilastik_object_classifier(
         template_ilp: Path,
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 1cdd5477..2cd56349 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -5,6 +5,7 @@ import numpy as np
 import vigra
 
 import model_server.extensions.ilastik.conf
+from extensions.chaeo.accessors import PatchStack
 from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
 from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
 
@@ -180,3 +181,38 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         return obmap
 
 
+class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
+    """
+    Wrap ilastik object classification for inputs comprising single-object series of raw images and binary
+    segmentation masks.
+    """
+
+    def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
+        assert segmentation_acc.is_mask()
+        assert input_acc.chroma == 1
+
+        tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
+        tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
+
+        dsi = [
+            {
+                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
+                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
+            }
+        ]
+
+        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)  # [z x h x w x n]
+
+        assert len(obmaps) == 1, 'ilastik generated more than one object map'
+
+        # for some reason ilastik scrambles these axes to Z(1)YX(1)
+        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
+        yxz = np.moveaxis(
+            obmaps[0][:, 0, :, :, 0],
+            [1, 2, 0],
+            [0, 1, 2]
+        )
+
+        assert yxz.shape[0:2] == input_acc.hw
+        assert yxz.shape[2] == input_acc.nz
+        return PatchStack(data=yxz), {'success': True}
-- 
GitLab