diff --git a/model_server/image.py b/model_server/image.py index c700ea66bb00e810fb4acc5887b20a7d8e5679e2..5d16d980f26cf037bcab7e72379a722eba5ec0f0 100644 --- a/model_server/image.py +++ b/model_server/image.py @@ -24,7 +24,7 @@ class GenericImageDataAccessor(ABC): @staticmethod def conform_data(data): - if len(data.shape) > 4: + if len(data.shape) > 4 or (0 in data.shape): raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}') ones = [1 for i in range(0, 4 - len(data.shape))] return data.reshape(*data.shape, *ones) @@ -33,7 +33,8 @@ class GenericImageDataAccessor(ABC): return True if self.shape_dict['Z'] > 1 else False def get_one_channel_data (self, channel: int): - return InMemoryDataAccessor(self.data[:, :, int(channel), :]) + c = int(channel) + return InMemoryDataAccessor(self.data[:, :, c:(c+1), :]) @property def data(self): diff --git a/tests/test_image.py b/tests/test_image.py index 7f6df8220c88d386ee93c2bfa1b623b634aeb527..83d25ba1488080aa657c7bfe212c84a87387a1f0 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -20,6 +20,16 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertEqual(cf.shape[0], czifile['h']) self.assertEqual(cf.shape[1], czifile['w']) + def test_get_single_channel_from_zstack(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + c = 3 + cf = InMemoryDataAccessor(np.random.rand(w, h, nc, nz)) + sc = cf.get_one_channel_data(c) + self.assertEqual(sc.shape, (w, h, 1, nz)) + def test_write_single_channel_tif(self): ch = 4 cf = CziImageFileAccessor(czifile['path'])