From 40be402fb4d097432a30d760d0fd709dccc91c12 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 2 Nov 2023 15:40:12 +0100
Subject: [PATCH] Added 3D patch accessor

---
 extensions/chaeo/accessors.py            | 33 ++++++++++++++++++++++++
 extensions/chaeo/tests/test_accessors.py | 29 +++++++++++++++++----
 2 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/extensions/chaeo/accessors.py b/extensions/chaeo/accessors.py
index 2b23318b..c62a7ff5 100644
--- a/extensions/chaeo/accessors.py
+++ b/extensions/chaeo/accessors.py
@@ -63,6 +63,39 @@ class MonoPatchStackFromFile(MonoPatchStack):
     def fpath(self):
         return self.file_acc.fpath
 
+class PatchStack3D(InMemoryDataAccessor):
+
+    def __init__(self, data):
+        """
+        A sequence of n monochrome 3D images of the same size
+        :param data: a list of np.ndarrays of size YXZ
+        """
+
+        if isinstance(data, list):  # list of YXZ patches
+            nda = np.array(data).squeeze()
+            assert nda.ndim == 4
+            self._data = np.moveaxis(
+                    nda,
+                    [1, 2, 0, 3],
+                    [0, 1, 2, 3]
+            )
+        else:
+            raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')
+
+    def iat(self, i):
+        return self.data[:, :, i, :]
+
+    def iat_yxcz(self, i):
+        return np.expand_dims(self.iat(i), 2)
+
+    @property
+    def chroma(self):
+        return 1
+
+    @property
+    def count(self):
+        return self.shape[2]
+
 class Error(Exception):
     pass
 
diff --git a/extensions/chaeo/tests/test_accessors.py b/extensions/chaeo/tests/test_accessors.py
index c74fdd86..461c34d7 100644
--- a/extensions/chaeo/tests/test_accessors.py
+++ b/extensions/chaeo/tests/test_accessors.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 
 from conf.testing import monozstackmask
-from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
+from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile, PatchStack3D
 
 
 
@@ -11,7 +11,7 @@ class TestCziImageFileAccess(unittest.TestCase):
     def setUp(self) -> None:
         pass
 
-    def test_make_patch_stack_from_list(self):
+    def test_make_patch_stack_from_3d_array(self):
         w = 256
         h = 512
         n = 4
@@ -20,11 +20,11 @@ class TestCziImageFileAccess(unittest.TestCase):
         self.assertEqual(acc.hw, (h, w))
         self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w))
 
-    def test_make_patch_stack_from_3d_array(self):
+    def test_make_patch_stack_from_list(self):
         w = 256
         h = 512
         n = 4
-        acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, 4)])
+        acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, n)])
         self.assertEqual(acc.count, n)
         self.assertEqual(acc.hw, (h, w))
         self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w))
@@ -50,4 +50,23 @@ class TestCziImageFileAccess(unittest.TestCase):
         h = 512
         n = 4
         acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, 4)])
-        self.assertEqual(acc.iat_yxcz(0).shape, (h, w, 1, 1))
\ No newline at end of file
+        self.assertEqual(acc.iat_yxcz(0).shape, (h, w, 1, 1))
+
+    def test_make_3d_patch_stack_from_list(self):
+        w = 256
+        h = 512
+        nz = 5
+        n = 4
+        acc = PatchStack3D([np.random.rand(h, w, nz) for _ in range(0, n)])
+        self.assertEqual(acc.count, n)
+        self.assertEqual(acc.hw, (h, w))
+        self.assertEqual(acc.chroma, 1)
+        self.assertEqual(acc.iat(0).shape, (h, w, nz))
+
+    def test_3d_patch_as_yxcz_array(self):
+        w = 256
+        h = 512
+        nz = 5
+        n = 4
+        acc = PatchStack3D([np.random.rand(h, w, nz) for _ in range(0, n)])
+        self.assertEqual(acc.iat_yxcz(0).shape, (h, w, 1, nz))
\ No newline at end of file
-- 
GitLab