From 6942c0446877ca614e68879f58c013a74fb1ebda Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 16 Oct 2023 17:26:32 +0200
Subject: [PATCH] Needed to modify model wrapper to handle patch stacks

---
 ...fer_labels_to_ilastik_object_classifier.py | 51 +++++++++++++++++--
 1 file changed, 46 insertions(+), 5 deletions(-)

diff --git a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
index 2fee54bc..3c721d25 100644
--- a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
+++ b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
@@ -5,10 +5,51 @@ import json
 import numpy as np
 import pandas as pd
 import uuid
+import vigra
 
 from extensions.chaeo.util import autonumber_new_file
 from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
-from model_server.accessors import generate_file_accessor, write_accessor_data_to_file
+from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
+
+class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
+
+    @staticmethod
+    def make_tczyx(acc):
+        assert acc.chroma == 1
+        tyx = np.moveaxis(
+            acc.data[:, :, 0, :], # YX(C)Z
+            [2, 0, 1],
+            [0, 1, 2]
+        )
+        return np.expand_dims(tyx, (1, 2))
+        # return tyx
+
+    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        assert segmentation_img.is_mask()
+        assert input_img.chroma == 1
+
+        tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
+        tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')
+
+        dsi = [
+            {
+                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
+                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
+            }
+        ]
+
+        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
+
+        assert len(obmaps) == 1, 'ilastik generated more than one object map'
+        assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1) # z(1)yx(1)
+
+        yxcz = np.moveaxis(
+            obmaps[0][:, :, :, :, 0],
+            [2, 3, 1, 0],
+            [0, 1, 2, 3]
+        )
+        assert yxcz.shape == input_img.shape
+        return InMemoryDataAccessor(data=yxcz), {'success': True}
 
 def get_dataset_info(h5, lane=0):
     lns = f'{lane:04d}'
@@ -154,8 +195,8 @@ if __name__ == '__main__':
     train_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_train_mask.tif')
 
     new_ilp = root / 'exp0014/test_obj_from_seg.ilp'
-    mod = IlastikObjectClassifierFromSegmentationModel({'project_file': new_ilp})
+    mod = PatchStackObjectClassifier({'project_file': new_ilp})
 
-    result = mod.infer(train_zstack_raw, train_zstack_mask)
-    write_accessor_data_to_file(where_patch_stack / 'result.tif', result)
-    print(mod.project_file_abspath)
\ No newline at end of file
+    result_acc, _ = mod.infer(train_zstack_raw, train_zstack_mask)
+    write_accessor_data_to_file(where_patch_stack / 'result.tif', result_acc)
+    print(where_patch_stack / 'result.tif')
\ No newline at end of file
-- 
GitLab