From 7f82c04907eaca1355bb9da75eca72543f2ae46f Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Aug 2023 10:44:40 +0200
Subject: [PATCH] Pixel map generation passes test, confirmed image orientation
 and format are correct

---
 model_server/ilastik.py | 24 ++++++++++++-----------
 model_server/image.py   |  7 ++++++-
 tests/test_ilastik.py   | 43 +++++++++++++++++++++++++++++------------
 3 files changed, 50 insertions(+), 24 deletions(-)

diff --git a/model_server/ilastik.py b/model_server/ilastik.py
index 7be2269c..aadfad6a 100644
--- a/model_server/ilastik.py
+++ b/model_server/ilastik.py
@@ -7,6 +7,7 @@ from ilastik.workflows.pixelClassification import PixelClassificationWorkflow
 from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow
 
 import numpy as np
+import vigra
 
 from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor
 from model_server.model import ImageToImageModel, ParameterExpectedError
@@ -14,11 +15,6 @@ from model_server.model import ImageToImageModel, ParameterExpectedError
 
 class IlastikImageToImageModel(ImageToImageModel):
 
-    # workflows = {
-    #     'pixel_classification': PixelClassificationWorkflow,
-    #     'object_classification': ObjectClassificationWorkflow,
-    # }
-
     def __init__(self, params, autoload=True):
         if 'project_file' not in params or not os.path.exists(params['project_file']):
             raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
@@ -56,20 +52,26 @@ class IlastikImageToImageModel(ImageToImageModel):
             return workflow.objectClassificationApplet.topLevelOperator
 
 
-    # def infer(self, img, channel=None) -> (np.ndarray, dict):
-    #     assert self.operator.Classifier.ready()
-
 class IlastikPixelClassifierModel(IlastikImageToImageModel):
     workflow = PixelClassificationWorkflow
 
     def infer(self, input_img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict):
+        tagged_input_data = vigra.taggedView(input_img.data, 'xycz')
         dsi = [
             {
-                'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data[0:3]),
+                'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
             }
         ]
-        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
-        return InMemoryDataAccessor(data=pxmaps[0])
+        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [1 x w x h x n]
+
+        assert(len(pxmaps) == 1, 'ilastik generated more than on pixel map')
+
+        xycz = np.moveaxis(
+            pxmaps[0],
+            [2, 1, 3, 0],
+            [0, 1, 2, 3]
+        )
+        return InMemoryDataAccessor(data=xycz)
 
 
 
diff --git a/model_server/image.py b/model_server/image.py
index 863c2ec6..ca015a9c 100644
--- a/model_server/image.py
+++ b/model_server/image.py
@@ -95,7 +95,12 @@ class CziImageFileAccessor(GenericImageFileAccessor):
 
 def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor) -> bool:
     try:
-        tifffile.imwrite(fpath, accessor.data)
+        zcxy = np.moveaxis(
+            accessor.data,
+            [3, 2, 0, 1],
+            [0, 1, 2, 3]
+        )
+        tifffile.imwrite(fpath, zcxy, imagej=True)
     except:
         raise FileWriteError(f'Unable to write data to file')
     return True
diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py
index 94a7cbbb..ac387a41 100644
--- a/tests/test_ilastik.py
+++ b/tests/test_ilastik.py
@@ -1,6 +1,9 @@
 import unittest
+
+import numpy as np
+
 from conf.testing import czifile, ilastik, output_path
-from model_server.image import CziImageFileAccessor, write_accessor_data_to_file
+from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.ilastik import IlastikPixelClassifierModel
 
 class TestIlastikPixelClassification(unittest.TestCase):
@@ -18,21 +21,37 @@ class TestIlastikPixelClassification(unittest.TestCase):
         with self.assertRaises(io.UnsupportedOperation):
             faulthandler.enable(file=sys.stdout)
 
+    def test_run_pixel_classifier_on_random_data(self):
+        model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
+        w = 512
+        h = 256
+
+        input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
+
+        pxmap = model.infer(input_img)
+        self.assertEqual(pxmap.shape, (w, h, 2, 1))
+
     def test_run_pixel_classifier(self):
         channel = 2
         model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
         cf = CziImageFileAccessor(czifile['path'])
-        pxmap = model.infer(cf.get_one_channel_data(channel))
+        mono_image = cf.get_one_channel_data(channel)
+
+        self.assertEqual(mono_image.shape_dict['X'], czifile['w'])
+        self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
+        self.assertEqual(mono_image.shape_dict['C'], 1)
+        self.assertEqual(mono_image.shape_dict['Z'], 1)
+
+        pxmap = model.infer(mono_image)
 
+        self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
+        self.assertEqual(pxmap.shape_dict['C'], 2)
+        self.assertEqual(pxmap.shape_dict['Z'], 1)
         print(pxmap.shape_dict)
+        self.assertTrue(
+            write_accessor_data_to_file(
+                output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
+                pxmap
+            )
+        )
 
-        # self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
-        # self.assertEqual(pxmap.shape_dict['C'], 2)
-        # self.assertEqual(pxmap.shape_dict['Z'], 1)
-        #
-        # self.assertTrue(
-        #     write_accessor_data_to_file(
-        #         output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
-        #         pxmap
-        #     )
-        # )
\ No newline at end of file
-- 
GitLab