From 6b7cdc2ee4c13b5debdc44cb61947ced87eabe0e Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 27 Nov 2024 13:33:32 +0100 Subject: [PATCH] Added alternate constructor for RoiSet from either YX or YXZ monochromatic data --- model_server/base/accessors.py | 9 +++++++++ tests/base/test_accessors.py | 23 +++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 7f1e763d..6dd63bc1 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -223,6 +223,15 @@ class InMemoryDataAccessor(GenericImageDataAccessor): self._data = self.conform_data(data) self.lazy = False + @classmethod + def from_mono(cls, data): + if len(data.shape) == 2: # interpret as YX + return cls(np.expand_dims(data, (2, 3))) + if len(data.shape) == 3: + return cls(np.expand_dims(data, 2)) + else: + raise InvalidDataShape(f'Expecting either YX or YXZ monochromatic data') + class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file def __init__(self, fpath: Path, lazy=False): """ diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index cdaa8345..f216f0ca 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -114,6 +114,25 @@ class TestCziImageFileAccess(unittest.TestCase): acc = InMemoryDataAccessor(_random_int(h, w, nc, nz)) self.assertEqual(acc.get_mono(0).data_mono.shape, (h, w, nz)) + def test_make_from_mono_2d(self): + w = 256 + h = 512 + nda = _random_int(h, w) + acc = InMemoryDataAccessor.from_mono(nda) + self.assertEqual(acc.chroma, 1) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.nz, 1) + + def test_make_from_mono_3d(self): + w = 256 + h = 512 + nz = 11 + nda = _random_int(h, w, nz) + acc = InMemoryDataAccessor.from_mono(nda) + self.assertEqual(acc.chroma, 1) + self.assertEqual(acc.hw, (h, w)) + self.assertEqual(acc.nz, nz) + def test_crop_yx(self): w = 256 h = 512 @@ -320,8 +339,8 @@ class TestPatchStackAccessor(unittest.TestCase): # test that this persists after channel selection for i in range(0, acc.count): - self.assertEqual(patches[i].shape[0:2], acc.get_channels([0]).iat_data(i, crop=True).shape[0:2]) - self.assertEqual(patches[i].shape[3], acc.get_channels([0]).iat_data(i, crop=True).shape[3]) + self.assertEqual(patches[i].shape[0:2], acc.get_channels([0]).iat(i, crop=True).shape[0:2]) + self.assertEqual(patches[i].shape[3], acc.get_channels([0]).iat(i, crop=True).shape[3]) def test_write_nonuniform_patches(self): w = 256 -- GitLab