From 6b7cdc2ee4c13b5debdc44cb61947ced87eabe0e Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 27 Nov 2024 13:33:32 +0100
Subject: [PATCH] Added alternate constructor for RoiSet from either YX or YXZ
 monochromatic data

---
 model_server/base/accessors.py |  9 +++++++++
 tests/base/test_accessors.py   | 23 +++++++++++++++++++++--
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 7f1e763d..6dd63bc1 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -223,6 +223,15 @@ class InMemoryDataAccessor(GenericImageDataAccessor):
         self._data = self.conform_data(data)
         self.lazy = False
 
+    @classmethod
+    def from_mono(cls, data):
+        if len(data.shape) == 2: # interpret as YX
+            return cls(np.expand_dims(data, (2, 3)))
+        if len(data.shape) == 3:
+            return cls(np.expand_dims(data, 2))
+        else:
+            raise InvalidDataShape(f'Expecting either YX or YXZ monochromatic data')
+
 class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file
     def __init__(self, fpath: Path, lazy=False):
         """
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index cdaa8345..f216f0ca 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -114,6 +114,25 @@ class TestCziImageFileAccess(unittest.TestCase):
         acc = InMemoryDataAccessor(_random_int(h, w, nc, nz))
         self.assertEqual(acc.get_mono(0).data_mono.shape, (h, w, nz))
 
+    def test_make_from_mono_2d(self):
+        w = 256
+        h = 512
+        nda = _random_int(h, w)
+        acc = InMemoryDataAccessor.from_mono(nda)
+        self.assertEqual(acc.chroma, 1)
+        self.assertEqual(acc.hw, (h, w))
+        self.assertEqual(acc.nz, 1)
+
+    def test_make_from_mono_3d(self):
+        w = 256
+        h = 512
+        nz = 11
+        nda = _random_int(h, w, nz)
+        acc = InMemoryDataAccessor.from_mono(nda)
+        self.assertEqual(acc.chroma, 1)
+        self.assertEqual(acc.hw, (h, w))
+        self.assertEqual(acc.nz, nz)
+
     def test_crop_yx(self):
         w = 256
         h = 512
@@ -320,8 +339,8 @@ class TestPatchStackAccessor(unittest.TestCase):
 
         # test that this persists after channel selection
         for i in range(0, acc.count):
-            self.assertEqual(patches[i].shape[0:2], acc.get_channels([0]).iat_data(i, crop=True).shape[0:2])
-            self.assertEqual(patches[i].shape[3], acc.get_channels([0]).iat_data(i, crop=True).shape[3])
+            self.assertEqual(patches[i].shape[0:2], acc.get_channels([0]).iat(i, crop=True).shape[0:2])
+            self.assertEqual(patches[i].shape[3], acc.get_channels([0]).iat(i, crop=True).shape[3])
 
     def test_write_nonuniform_patches(self):
         w = 256
-- 
GitLab