-
Christopher Randolph Rhodes authored
test_accessors.py 9.70 KiB
import unittest
import numpy as np
from model_server.base.accessors import PatchStack, make_patch_stack_from_file, FileNotFoundError
from tests.base._conf import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask
from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
def _random_int(*args):
    """Return a uint8 array of shape ``args`` filled with random values in [0, 255].

    Draws from NumPy's global RNG, so results depend on its current state.
    """
    shape = args
    upper_exclusive = 2 ** 8  # randint's high bound is exclusive -> max value 255
    return np.random.randint(0, upper_exclusive, size=shape, dtype=np.uint8)
class TestCziImageFileAccess(unittest.TestCase):
    """Tests for reading, slicing and writing image data through the accessors.

    Despite the class name, this covers TIF, CZI, PNG and in-memory accessors.
    Fixture descriptors (``czifile``, ``tifffile``, ...) are dicts from
    ``tests.base._conf``; their ``'path'`` entry is opened and their dimension
    entries (``'h'``, ``'w'``, ``'c'``, ...) are used as expected values.
    Accessor data is ordered (Y, X, C, Z).
    """

    def test_tiffile_is_correct_shape(self):
        tf = generate_file_accessor(tifffile['path'])
        self.assertIsInstance(tf, TifSingleSeriesFileAccessor)
        self.assertEqual(tf.shape_dict['Y'], tifffile['h'])
        self.assertEqual(tf.shape_dict['X'], tifffile['w'])
        self.assertEqual(tf.chroma, tifffile['c'])
        self.assertTrue(tf.is_3d())
        self.assertEqual(len(tf.data.shape), 4)
        self.assertEqual(tf.shape[0], tifffile['h'])
        self.assertEqual(tf.shape[1], tifffile['w'])
        self.assertEqual(tf.get_axis('x'), 1)

    def test_czifile_is_correct_shape(self):
        cf = CziImageFileAccessor(czifile['path'])
        self.assertEqual(cf.shape_dict['Y'], czifile['h'])
        self.assertEqual(cf.shape_dict['X'], czifile['w'])
        self.assertEqual(cf.chroma, czifile['c'])
        self.assertFalse(cf.is_3d())
        self.assertEqual(len(cf.data.shape), 4)
        self.assertEqual(cf.shape[0], czifile['h'])
        self.assertEqual(cf.shape[1], czifile['w'])

    def test_get_single_channel_from_zstack(self):
        w = 256
        h = 512
        nc = 4
        nz = 11
        c = 3
        cf = InMemoryDataAccessor(_random_int(h, w, nc, nz))
        sc = cf.get_mono(c)
        self.assertEqual(sc.shape, (h, w, 1, nz))

    def test_get_single_channel_mip_from_zstack(self):
        w = 256
        h = 512
        nc = 4
        nz = 11
        c = 3
        cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz))
        sc = cf.get_mono(c, mip=True)
        # the max-intensity projection collapses Z to a single slice
        self.assertEqual(sc.shape, (h, w, 1, 1))

    def test_write_single_channel_tif(self):
        ch = 4
        cf = CziImageFileAccessor(czifile['path'])
        mono = cf.get_mono(ch)
        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'{cf.fpath.stem}_ch{ch}.tif',
                mono
            )
        )
        # YXCZ order: Y, X and Z must survive channel extraction; C collapses to 1.
        # (Previously Z of the source was compared to C of the mono image, which
        # only passed because the CZI fixture has a single Z slice.)
        self.assertEqual(cf.data.shape[0:2], mono.data.shape[0:2])
        self.assertEqual(mono.data.shape[2], 1)
        self.assertEqual(cf.data.shape[3], mono.data.shape[3])

    def test_write_two_channel_png(self):
        from model_server.base.process import resample_to_8bit
        cf = CziImageFileAccessor(czifile['path'])
        acc = cf.get_channels([0, 1])
        opa = output_path / f'{cf.fpath.stem}_2ch.png'
        acc_out = acc.apply(resample_to_8bit)
        self.assertTrue(
            write_accessor_data_to_file(opa, acc_out)
        )
        # a two-channel image written as PNG is read back with three channels
        val_acc = generate_file_accessor(opa)
        self.assertEqual(val_acc.chroma, 3)
        self.assertTrue(np.all(val_acc.data[:, :, -1, :] == 0))  # final channel is blank

    def test_conform_data_shorter_than_xycz(self):
        h = 256
        w = 512
        data = _random_int(h, w, 1)
        acc = InMemoryDataAccessor(data)
        # missing trailing axes are padded to length 1
        self.assertEqual(
            InMemoryDataAccessor.conform_data(data).shape,
            (h, w, 1, 1)
        )
        self.assertEqual(
            acc.shape_dict,
            {'Y': h, 'X': w, 'C': 1, 'Z': 1}
        )

    def test_conform_data_longer_than_xycz(self):
        data = _random_int(256, 512, 12, 8, 3)
        with self.assertRaises(DataShapeError):
            InMemoryDataAccessor(data)

    def test_write_multichannel_image_preserve_axes(self):
        h = 256
        w = 512
        c = 3
        nz = 10
        yxcz = _random_int(h, w, c, nz)
        acc = InMemoryDataAccessor(yxcz)
        fp = output_path / 'rand3d.tif'
        self.assertTrue(
            write_accessor_data_to_file(fp, acc)
        )
        # NOTE: x,y flipping still needs sorting out, since the numpy YXCZ
        # convention flips axes in 3d tif output
        self.assertEqual(acc.shape_dict['X'], w, acc.shape_dict)
        self.assertEqual(acc.shape_dict['Y'], h, acc.shape_dict)
        # re-open the file and check the axes order; close the handle when done
        from tifffile import TiffFile
        with TiffFile(fp) as fh:
            self.assertEqual(len(fh.series), 1)
            se = fh.series[0]
            fh_shape_dict = dict(zip(se.axes, se.shape))
        self.assertEqual(fh_shape_dict, acc.shape_dict, 'Axes are not preserved in TIF output')

    def _check_png(self, pngfile):
        """Assert that the PNG described by fixture dict ``pngfile`` reads back correctly."""
        acc = PngFileAccessor(pngfile['path'])
        self.assertEqual(acc.hw, (pngfile['h'], pngfile['w']))
        self.assertEqual(acc.chroma, pngfile['c'])
        self.assertEqual(acc.nz, 1)

    def test_read_png(self, pngfile=rgbpngfile):
        self._check_png(pngfile)

    def test_read_mono_png(self):
        self._check_png(monopngfile)

    def test_read_zstack_mono_mask(self):
        acc = generate_file_accessor(monozstackmask['path'])
        self.assertTrue(acc.is_mask())

    def test_read_in_pixel_scale_from_czi(self):
        cf = CziImageFileAccessor(czifile['path'])
        pxs = cf.pixel_scale_in_micrometers
        self.assertAlmostEqual(pxs['X'], czifile['um_per_pixel'], places=3)
class TestPatchStackAccessor(unittest.TestCase):
    """Tests for PatchStack construction, channel slicing and file round-trips.

    Patch data is checked in (P, Y, X, C, Z) order via ``pyxcz`` and in
    (P, C, Z, Y, X) order via ``pczyx``; ``count`` is the number of patches.
    """

    def _make_pczyx_stack(self, n=4, h=512, w=256, nc=3, nz=15):
        """Return a PatchStack of ``n`` random uint8 patches, each shaped (h, w, nc, nz).

        Shared fixture for the channel/export tests. Test methods must not
        return values, because unittest deprecates that since Python 3.11.
        """
        return PatchStack(_random_int(n, h, w, nc, nz))

    def test_make_patch_stack_from_3d_array(self):
        w = 256
        h = 512
        n = 4
        acc = PatchStack(_random_int(n, h, w, 1, 1))
        self.assertEqual(acc.count, n)
        self.assertEqual(acc.hw, (h, w))
        self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
        self.assertEqual(acc.shape[1:], acc.iat(0, crop=True).shape)

    def test_make_patch_stack_from_list(self):
        w = 256
        h = 512
        n = 4
        acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(n)])
        self.assertEqual(acc.count, n)
        self.assertEqual(acc.hw, (h, w))
        self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))

    def test_make_patch_stack_from_file(self):
        h = monozstackmask['h']
        w = monozstackmask['w']
        c = monozstackmask['c']
        n = monozstackmask['z']
        # each Z slice of the file becomes one patch
        acc = make_patch_stack_from_file(monozstackmask['path'])
        self.assertEqual(acc.hw, (h, w))
        self.assertEqual(acc.count, n)
        self.assertEqual(acc.pyxcz.shape, (n, h, w, c, 1))

    def test_raises_filenotfound(self):
        # FileNotFoundError here is the name imported from model_server.base.accessors
        with self.assertRaises(FileNotFoundError):
            make_patch_stack_from_file('c:/fake/file/name.tif')

    def test_make_3d_patch_stack_from_nonuniform_list(self):
        w = 256
        h = 512
        c = 1
        nz = 5
        n = 4
        patches = [_random_int(h, w, c, nz) for _ in range(n)]
        patches.append(_random_int(h, 2 * w, c, nz))
        acc = PatchStack(patches)
        # the stack is sized to the largest patch; smaller patches are padded
        self.assertEqual(acc.count, n + 1)
        self.assertEqual(acc.hw, (h, 2 * w))
        self.assertEqual(acc.chroma, c)
        self.assertEqual(acc.iat(0).shape, (h, 2 * w, c, nz))
        self.assertEqual(acc.iat_yxcz(0).shape, (h, 2 * w, c, nz))
        # test that initial patches are maintained
        for i, patch in enumerate(patches):
            self.assertEqual(patch.shape, acc.iat(i, crop=True).shape)
            self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape)

    def test_make_3d_patch_stack_from_list_force_long_dim(self):
        patches = [
            _random_int(256, 128, 1, 1),
            _random_int(128, 256, 1, 1),
            _random_int(512, 10, 1, 1),
            _random_int(10, 512, 1, 1),
        ]
        acc_ref = PatchStack(patches, force_ydim_longest=False)
        self.assertEqual(acc_ref.hw, (512, 512))
        self.assertEqual(acc_ref.iat(-1, crop=False).hw, (512, 512))
        self.assertEqual(acc_ref.iat(-1, crop=True).hw, (10, 512))

        # rotating wide patches so Y is longest lets the stack use a smaller canvas
        acc_rot = PatchStack(patches, force_ydim_longest=True)
        self.assertEqual(acc_rot.hw, (512, 128))
        self.assertEqual(acc_rot.iat(-1, crop=False).hw, (512, 128))
        self.assertEqual(acc_rot.iat(-1, crop=True).hw, (512, 10))
        nda_rot_rot = np.rot90(acc_rot.iat(-1, crop=True).data, axes=(1, 0))
        nda_ref = acc_ref.iat(-1, crop=True).data
        self.assertTrue(np.all(nda_ref == nda_rot_rot))
        self.assertLess(acc_rot.data.size, acc_ref.data.size)

    def test_pczyx(self):
        w = 256
        h = 512
        n = 4
        nz = 15
        nc = 3
        acc = self._make_pczyx_stack(n=n, h=h, w=w, nc=nc, nz=nz)
        self.assertEqual(acc.count, n)
        self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w))
        self.assertEqual(acc.hw, (h, w))

    def test_get_one_channel(self):
        acc = self._make_pczyx_stack()
        mono = acc.get_mono(channel=1)
        for a in 'PXYZ':
            self.assertEqual(mono.shape_dict[a], acc.shape_dict[a])
        self.assertEqual(mono.shape_dict['C'], 1)

    def test_get_multiple_channels(self):
        acc = self._make_pczyx_stack()
        channels = [0, 1]
        mcacc = acc.get_channels(channels=channels)
        for a in 'PXYZ':
            self.assertEqual(mcacc.shape_dict[a], acc.shape_dict[a])
        self.assertEqual(mcacc.shape_dict['C'], len(channels))

    def test_get_one_channel_mip(self):
        acc = self._make_pczyx_stack()
        mono_mip = acc.get_mono(channel=1, mip=True)
        for a in 'PXY':
            self.assertEqual(mono_mip.shape_dict[a], acc.shape_dict[a])
        for a in 'CZ':
            self.assertEqual(mono_mip.shape_dict[a], 1)

    def test_export_pczyx_patch_hyperstack(self):
        acc = self._make_pczyx_stack()
        fp = output_path / 'patch_hyperstack.tif'
        acc.export_pyxcz(fp)
        # round-trip: re-reading the exported file must give the same shape
        acc2 = make_patch_stack_from_file(fp)
        self.assertEqual(acc.shape, acc2.shape)