From b0847605bf7d00177938a4a26b2e288089be1f38 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 2 Nov 2023 15:50:09 +0100
Subject: [PATCH] Expanded 3d patch stack as 3d multichannel patch stack

---
 extensions/chaeo/accessors.py            | 47 +++++++++++++++---------
 extensions/chaeo/tests/test_accessors.py | 14 ++++---
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/extensions/chaeo/accessors.py b/extensions/chaeo/accessors.py
index c62a7ff5..f606470e 100644
--- a/extensions/chaeo/accessors.py
+++ b/extensions/chaeo/accessors.py
@@ -63,38 +63,51 @@ class MonoPatchStackFromFile(MonoPatchStack):
     def fpath(self):
         return self.file_acc.fpath
 
-class PatchStack3D(InMemoryDataAccessor):
+class Multichannel3dPatchStack(InMemoryDataAccessor):
 
     def __init__(self, data):
         """
-        A sequence of n monochrome 3D images of the same size
-        :param data: a list of np.ndarrays of size YXZ
+        A sequence of n (generally) color 3D images of the same size
+        :param data: a list of np.ndarrays of size YXCZ
         """
 
-        if isinstance(data, list):  # list of YXZ patches
-            nda = np.array(data).squeeze()
-            assert nda.ndim == 4
-            self._data = np.moveaxis(
-                    nda,
-                    [1, 2, 0, 3],
-                    [0, 1, 2, 3]
-            )
+        if isinstance(data, list):  # list of YXCZ patches
+            nda = np.array(data)
+            assert nda.ndim == 5
+            # self._data = np.moveaxis( # pos-YXCZ
+            #         nda,
+            #         [0, 1, 2, 0, 3],
+            #         [0, 1, 2, 3]
+            # )
+            self._data = nda
         else:
             raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')
 
     def iat(self, i):
-        return self.data[:, :, i, :]
+        return self.data[i, :, :, :, :]
 
     def iat_yxcz(self, i):
-        return np.expand_dims(self.iat(i), 2)
+        return self.iat(i)
 
     @property
-    def chroma(self):
-        return 1
+    def count(self):
+        return self.shape_dict['P']
 
     @property
-    def count(self):
-        return self.shape[2]
+    def data(self):
+        """
+        Return data as 5d with axes in order of pos, Y, X, C, Z
+        :return: np.ndarray
+        """
+        return self._data
+
+    @property
+    def shape(self):
+        return self._data.shape
+
+    @property
+    def shape_dict(self):
+        return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
 
 class Error(Exception):
     pass
diff --git a/extensions/chaeo/tests/test_accessors.py b/extensions/chaeo/tests/test_accessors.py
index 461c34d7..c2194581 100644
--- a/extensions/chaeo/tests/test_accessors.py
+++ b/extensions/chaeo/tests/test_accessors.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 
 from conf.testing import monozstackmask
-from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile, PatchStack3D
+from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile, Multichannel3dPatchStack
 
 
 
@@ -55,18 +55,20 @@ class TestCziImageFileAccess(unittest.TestCase):
     def test_make_3d_patch_stack_from_list(self):
         w = 256
         h = 512
+        c = 1
         nz = 5
         n = 4
-        acc = PatchStack3D([np.random.rand(h, w, nz) for _ in range(0, n)])
+        acc = Multichannel3dPatchStack([np.random.rand(h, w, c, nz) for _ in range(0, n)])
         self.assertEqual(acc.count, n)
         self.assertEqual(acc.hw, (h, w))
-        self.assertEqual(acc.chroma, 1)
-        self.assertEqual(acc.iat(0).shape, (h, w, nz))
+        self.assertEqual(acc.chroma, c)
+        self.assertEqual(acc.iat(0).shape, (h, w, c, nz))
 
     def test_3d_patch_as_yxcz_array(self):
         w = 256
         h = 512
         nz = 5
+        c = 1
         n = 4
-        acc = PatchStack3D([np.random.rand(h, w, nz) for _ in range(0, n)])
-        self.assertEqual(acc.iat_yxcz(0).shape, (h, w, 1, nz))
\ No newline at end of file
+        acc = Multichannel3dPatchStack([np.random.rand(h, w, c, nz) for _ in range(0, n)])
+        self.assertEqual(acc.iat_yxcz(0).shape, (h, w, c, nz))
\ No newline at end of file
-- 
GitLab