From 3cc6ffd038960e1c82c90c48ae55133a07b94313 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 7 Sep 2023 10:56:31 +0200
Subject: [PATCH] Corrected and validated issue where axes are flipped in
 exported TIFs

---
 examples/ilastik3d.py   |  5 +++--
 model_server/ilastik.py | 22 +++++++++++-----------
 model_server/image.py   | 14 +++++++++-----
 tests/test_ilastik.py   |  4 ++--
 tests/test_image.py     | 24 ++++++++++++++++--------
 5 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/examples/ilastik3d.py b/examples/ilastik3d.py
index 0e381475..3f542533 100644
--- a/examples/ilastik3d.py
+++ b/examples/ilastik3d.py
@@ -65,6 +65,7 @@ if __name__ == '__main__':
 
     resp_models = requests.get(uri + 'models')
 
-    raise Exception('x, y axes are currently flipped in pixelmap')
+    # raise Exception('x, y axes are currently flipped in pixelmap')
 
-    server_process.terminate()
\ No newline at end of file
+    server_process.terminate()
+    print('Finished')
\ No newline at end of file
diff --git a/model_server/ilastik.py b/model_server/ilastik.py
index d2490ba9..aad0f38c 100644
--- a/model_server/ilastik.py
+++ b/model_server/ilastik.py
@@ -57,22 +57,22 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
         return PixelClassificationWorkflow
 
     def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        tagged_input_data = vigra.taggedView(input_img.data, 'xycz')
+        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
         dsi = [
             {
                 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
             }
         ]
-        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [1 x w x h x n]
+        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
 
         assert(len(pxmaps) == 1, 'ilastik generated more than one pixel map')
 
-        xycz = np.moveaxis(
+        yxcz = np.moveaxis(
             pxmaps[0],
-            [2, 1, 3, 0],
+            [1, 2, 3, 0],
             [0, 1, 2, 3]
         )
-        return InMemoryDataAccessor(data=xycz), {'success': True}
+        return InMemoryDataAccessor(data=yxcz), {'success': True}
 
 class IlastikObjectClassifierModel(IlastikImageToImageModel):
     model_id = 'ilastik_object_classification'
@@ -83,8 +83,8 @@ class IlastikObjectClassifierModel(IlastikImageToImageModel):
         return ObjectClassificationWorkflow
 
     def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        tagged_input_data = vigra.taggedView(input_img.data, 'xycz')
-        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'xycz')
+        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
+        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
 
         dsi = [
             {
@@ -93,13 +93,13 @@ class IlastikObjectClassifierModel(IlastikImageToImageModel):
             }
         ]
 
-        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
+        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
 
         assert (len(obmaps) == 1, 'ilastik generated more than one object map')
 
-        xycz = np.moveaxis(
+        yxcz = np.moveaxis(
             obmaps[0],
-            [2, 1, 3, 0],
+            [1, 2, 3, 0],
             [0, 1, 2, 3]
         )
-        return InMemoryDataAccessor(data=xycz), {'success': True}
+        return InMemoryDataAccessor(data=yxcz), {'success': True}
diff --git a/model_server/image.py b/model_server/image.py
index 8cbd37c1..d2b4da24 100644
--- a/model_server/image.py
+++ b/model_server/image.py
@@ -38,6 +38,10 @@ class GenericImageDataAccessor(ABC):
 
     @property
     def data(self):
+        """
+        Return data as 4d with axes in order of Y, X, C, Z
+        :return: np.ndarray
+        """
         return self._data
 
     @property
@@ -46,7 +50,7 @@ class GenericImageDataAccessor(ABC):
 
     @property
     def shape_dict(self):
-        return dict(zip(('X', 'Y', 'C', 'Z'), self.data.shape))
+        return dict(zip(('Y', 'X', 'C', 'Z'), self.data.shape))
 
 class InMemoryDataAccessor(GenericImageDataAccessor):
     def __init__(self, data):
@@ -96,12 +100,12 @@ class CziImageFileAccessor(GenericImageFileAccessor):
 
 def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor) -> bool:
     try:
-        zcxy = np.moveaxis(
-            accessor.data,
-            [3, 2, 1, 0],
+        zcyx= np.moveaxis(
+            accessor.data, # yxcz
+            [3, 2, 0, 1],
             [0, 1, 2, 3]
         )
-        tifffile.imwrite(fpath, zcxy, imagej=True)
+        tifffile.imwrite(fpath, zcyx, imagej=True)
     except:
         raise FileWriteError(f'Unable to write data to file')
     return True
diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py
index 25c876b9..4af941d1 100644
--- a/tests/test_ilastik.py
+++ b/tests/test_ilastik.py
@@ -44,10 +44,10 @@ class TestIlastikPixelClassification(unittest.TestCase):
         w = 512
         h = 256
 
-        input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
+        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
 
         pxmap, _ = model.infer(input_img)
-        self.assertEqual(pxmap.shape, (w, h, 2, 1))
+        self.assertEqual(pxmap.shape, (h, w, 2, 1))
 
 
     def test_run_pixel_classifier(self):
diff --git a/tests/test_image.py b/tests/test_image.py
index 00a746fe..6c085b6c 100644
--- a/tests/test_image.py
+++ b/tests/test_image.py
@@ -44,7 +44,9 @@ class TestCziImageFileAccess(unittest.TestCase):
         self.assertEqual(cf.data.shape[3], mono.data.shape[2])
 
     def test_conform_data_shorter_than_xycz(self):
-        data = np.random.rand(256, 512, 1)
+        h = 256
+        w = 512
+        data = np.random.rand(h, w, 1)
         acc = InMemoryDataAccessor(data)
         self.assertEqual(
             InMemoryDataAccessor.conform_data(data).shape,
@@ -52,7 +54,7 @@ class TestCziImageFileAccess(unittest.TestCase):
         )
         self.assertEqual(
             acc.shape_dict,
-            {'X': 256, 'Y': 512, 'C': 1, 'Z': 1}
+            {'Y': 256, 'X': 512, 'C': 1, 'Z': 1}
         )
 
     def test_conform_data_longer_than_xycz(self):
@@ -61,7 +63,7 @@ class TestCziImageFileAccess(unittest.TestCase):
             acc = InMemoryDataAccessor(data)
 
 
-    def test_write_multichannel_image_preserves_axes(self):
+    def test_write_multichannel_image_preserve_axes(self):
         h = 256
         w = 512
         c = 3
@@ -69,12 +71,18 @@ class TestCziImageFileAccess(unittest.TestCase):
 
         yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8')
         acc = InMemoryDataAccessor(yxcz)
+        fp = output_path / f'rand3d.tif'
         self.assertTrue(
-            write_accessor_data_to_file(
-                output_path / f'rand3d.tif',
-                acc
-            )
+            write_accessor_data_to_file(fp, acc)
         )
         # need to sort out x,y flipping since np convention yxcz flips axes in 3d tif
         self.assertEqual(acc.shape_dict['X'], w, acc.shape_dict)
-        self.assertEqual(acc.shape_dict['Y'], h, acc.shape_dict)
\ No newline at end of file
+        self.assertEqual(acc.shape_dict['Y'], h, acc.shape_dict)
+
+        # re-open file and check axes order
+        from tifffile import TiffFile
+        fh = TiffFile(fp)
+        self.assertEqual(len(fh.series), 1)
+        se = fh.series[0]
+        fh_shape_dict = {se.axes[i]: se.shape[i] for i in range(0, len(se.shape))}
+        self.assertEqual(fh_shape_dict, acc.shape_dict, 'Axes are not preserved in TIF output')
-- 
GitLab