From 6debd9c3708776b76d41910319b005377ca5fe8c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 26 Oct 2023 15:51:44 +0200
Subject: [PATCH] Use patch stack accessor in a batch script

---
 extensions/chaeo/accessors.py                 |  7 +++-
 ...fer_labels_to_ilastik_object_classifier.py | 36 ++++++++-----------
 extensions/chaeo/tests/test_accessors.py      | 10 +++---
 3 files changed, 26 insertions(+), 27 deletions(-)

diff --git a/extensions/chaeo/accessors.py b/extensions/chaeo/accessors.py
index 32fb7020..d1e9089c 100644
--- a/extensions/chaeo/accessors.py
+++ b/extensions/chaeo/accessors.py
@@ -1,5 +1,5 @@
 import numpy as np
-from model_server.accessors import InMemoryDataAccessor
+from model_server.accessors import generate_file_accessor, InMemoryDataAccessor
 
 class MonoPatchStack(InMemoryDataAccessor):
 
@@ -44,6 +44,11 @@ class MonoPatchStack(InMemoryDataAccessor):
         return [self.data[:, :, 0, zi] for zi in range(0, n)]
 
 
+class MonoPatchStackFromFile(MonoPatchStack):
+    def __init__(self, fpath):
+        super().__init__(generate_file_accessor(fpath).data[:, :, 0, :])
+
+
 class Error(Exception):
     pass
 
diff --git a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
index 49de008f..9e07ec2b 100644
--- a/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
+++ b/extensions/chaeo/examples/transfer_labels_to_ilastik_object_classifier.py
@@ -8,6 +8,7 @@ import skimage
 import uuid
 import vigra
 
+from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
 from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
 from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 
@@ -17,22 +18,13 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
     as time-series images where each frame contains only one object.
     """
 
-    @staticmethod
-    def make_tczyx(acc: GenericImageDataAccessor):
-        assert acc.chroma == 1
-        tyx = np.moveaxis(
-            acc.data[:, :, 0, :], # YX(C)Z
-            [2, 0, 1],
-            [0, 1, 2]
-        )
-        return np.expand_dims(tyx, (1, 2))
 
-    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        assert segmentation_img.is_mask()
-        assert input_img.chroma == 1
+    def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> (np.ndarray, dict):
+        assert segmentation_acc.is_mask()
+        assert input_acc.chroma == 1
 
-        tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
-        tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')
+        tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
+        tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')
 
         dsi = [
             {
@@ -46,14 +38,14 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
         assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
         # for some reason ilastik scrambles these axes to Z(1)YX(1)
-        assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1)
+        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
         yxcz = np.moveaxis(
             obmaps[0][:, :, :, :, 0],
             [2, 3, 1, 0],
             [0, 1, 2, 3]
         )
 
-        assert yxcz.shape == input_img.shape
+        assert yxcz.shape == input_acc.shape
         return InMemoryDataAccessor(data=yxcz), {'success': True}
 
 def get_dataset_info(h5: h5py.File, lane : int = 0):
@@ -236,16 +228,16 @@ if __name__ == '__main__':
 
     def infer_and_compare_training_set(ilp, suffix):
         # infer object labels from the same data used to train the classifier
-        train_zstack_raw = generate_file_accessor(where_patch_stack / 'zstack_train_raw.tif')
-        train_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_train_mask.tif')
-        train_truth_labels = generate_file_accessor(where_patch_stack / f'zstack_train_label.tif')
+        train_zstack_raw = MonoPatchStackFromFile(where_patch_stack / 'zstack_train_raw.tif')
+        train_zstack_mask = MonoPatchStackFromFile(where_patch_stack / 'zstack_train_mask.tif')
+        train_truth_labels = MonoPatchStackFromFile(where_patch_stack / f'zstack_train_label.tif')
         infer_and_compare(ilp, 'train', suffix, train_zstack_raw, train_zstack_mask, train_truth_labels)
 
     def infer_and_compare_test_set(ilp, suffix):
         # infer object labels from test dataset
-        test_zstack_raw = generate_file_accessor(where_patch_stack / 'zstack_test_raw.tif')
-        test_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_test_mask.tif')
-        test_truth_labels = generate_file_accessor(where_patch_stack / f'zstack_test_label.tif')
+        test_zstack_raw = MonoPatchStackFromFile(where_patch_stack / 'zstack_test_raw.tif')
+        test_zstack_mask = MonoPatchStackFromFile(where_patch_stack / 'zstack_test_mask.tif')
+        test_truth_labels = MonoPatchStackFromFile(where_patch_stack / f'zstack_test_label.tif')
         infer_and_compare(ilp, 'test', suffix, test_zstack_raw, test_zstack_mask, test_truth_labels)
 
     def infer_and_compare(ilp, prefix, suffix, raw, mask, labels):
diff --git a/extensions/chaeo/tests/test_accessors.py b/extensions/chaeo/tests/test_accessors.py
index 465f15e0..842b8fbc 100644
--- a/extensions/chaeo/tests/test_accessors.py
+++ b/extensions/chaeo/tests/test_accessors.py
@@ -15,13 +15,15 @@ class TestCziImageFileAccess(unittest.TestCase):
         h = 512
         n = 4
         acc = MonoPatchStack(np.random.rand(h, w, n))
-        assert acc.count == n
-        assert acc.hw == (h, w)
+        self.assertEqual(acc.count, n)
+        self.assertEqual(acc.hw, (h, w))
+        self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w))
 
     def test_make_patch_stack_from_3d_array(self):
         w = 256
         h = 512
         n = 4
         acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, 4)])
-        assert acc.count == n
-        assert acc.hw == (h, w)
\ No newline at end of file
+        self.assertEqual(acc.count, n)
+        self.assertEqual(acc.hw, (h, w))
+        self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w))
\ No newline at end of file
-- 
GitLab