From 30e33badd5ce6164aa30dce079976a54690433a8 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 27 Sep 2023 16:02:02 +0200
Subject: [PATCH] Added generic Tif accessor

---
 conf/testing.py                            | 10 ++++++++
 model_server/accessors.py                  | 28 ++++++++++++++++++++++
 tests/{test_image.py => test_accessors.py} | 14 +++++++++--
 3 files changed, 50 insertions(+), 2 deletions(-)
 rename tests/{test_image.py => test_accessors.py} (82%)

diff --git a/conf/testing.py b/conf/testing.py
index 55408381..a089ccc5 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -12,6 +12,16 @@ czifile = {
     'z': 1,
 }
 
+filename = 'zmask-test-stack.tif'
+tifffile = {
+    'filename': filename,
+    'path': root / filename,
+    'w': 512,
+    'h': 512,
+    'c': 2,
+    'z': 7,
+}
+
 ilastik = {
     'pixel_classifier': 'demo_px.ilp',
     'object_classifier': 'demo_obj.ilp',
diff --git a/model_server/accessors.py b/model_server/accessors.py
index d2b4da24..1a9b6b0b 100644
--- a/model_server/accessors.py
+++ b/model_server/accessors.py
@@ -67,6 +67,34 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded
             raise FileAccessorError(f'Could not find file at {fpath}')
         self.fpath = fpath
 
+class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
+    def __init__(self, fpath: Path):
+        super().__init__(fpath)
+
+        try:
+            tf = tifffile.TiffFile(fpath)
+            self.tf = tf
+        except Exception:
+            FileAccessorError(f'Unable to access data in {fpath}')
+
+        if len(tf.series) != 1:
+            raise DataShapeError(f'Expect only one series in {fpath}')
+
+        se = tf.series[0]
+        sd = {ch: se.shape[se.axes.index(ch)] for ch in se.axes}
+
+        idx = {k: sd[k] for k in ['Y', 'X', 'C', 'Z']}
+        yxcz = np.moveaxis(
+            se.asarray(),
+            [se.axes.index(ch) for ch in idx],
+            [0, 1, 2, 3]
+        )
+
+        self._data = self.conform_data(yxcz.reshape(yxcz.shape[0:4]))
+
+    def __del__(self):
+        self.tf.close()
+
 class CziImageFileAccessor(GenericImageFileAccessor):
     """
     Image that is stored in a Zeiss .CZI file; may be multi-channel, and/or a z-stack,
diff --git a/tests/test_image.py b/tests/test_accessors.py
similarity index 82%
rename from tests/test_image.py
rename to tests/test_accessors.py
index 3d590990..bb5c95cd 100644
--- a/tests/test_image.py
+++ b/tests/test_accessors.py
@@ -2,14 +2,24 @@ import unittest
 
 import numpy as np
 
-from conf.testing import czifile, output_path
-from model_server.accessors import CziImageFileAccessor, DataShapeError, InMemoryDataAccessor, write_accessor_data_to_file
+from conf.testing import czifile, output_path, tifffile
+from model_server.accessors import CziImageFileAccessor, DataShapeError, InMemoryDataAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
 
 class TestCziImageFileAccess(unittest.TestCase):
 
     def setUp(self) -> None:
         pass
 
+    def test_tiffile_is_correct_shape(self):
+        tf = TifSingleSeriesFileAccessor(tifffile['path'])
+        self.assertEqual(tf.shape_dict['Y'], tifffile['h'])
+        self.assertEqual(tf.shape_dict['X'], tifffile['w'])
+        self.assertEqual(tf.chroma, tifffile['c'])
+        self.assertTrue(tf.is_3d())
+        self.assertEqual(len(tf.data.shape), 4)
+        self.assertEqual(tf.shape[0], tifffile['h'])
+        self.assertEqual(tf.shape[1], tifffile['w'])
+
     def test_czifile_is_correct_shape(self):
         cf = CziImageFileAccessor(czifile['path'])
         self.assertEqual(cf.shape_dict['Y'], czifile['h'])
-- 
GitLab