Skip to content
Snippets Groups Projects
Commit 11f01a1b authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

PatchStack remembers crop boundaries of individual patches if created from...

PatchStack remembers crop boundaries of individual patches if created from list of non-matching patches
parent 2d91f027
No related branches found
No related tags found
No related merge requests found
......@@ -279,14 +279,21 @@ class PatchStack(InMemoryDataAccessor):
else:
raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')
self._slices = []
for i in range(0, len(data)):
self._slices.append(tuple([slice(0, c) for c in data[i].shape]))
assert nda.ndim == 5
self._data = nda
def iat(self, i):
return InMemoryDataAccessor(self.data[i, :, :, :, :])
def iat(self, i, crop=False):
if crop:
return InMemoryDataAccessor(self.data[i, :, :, :, :][self._slices[i]])
else:
return InMemoryDataAccessor(self.data[i, :, :, :, :])
def iat_yxcz(self, i):
return self.iat(i)
def iat_yxcz(self, i, crop=False):
return self.iat(i, crop=crop)
@property
def count(self):
......
......@@ -146,6 +146,8 @@ class TestPatchStackAccessor(unittest.TestCase):
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
self.assertEqual(acc.shape[1:], acc.iat(0, crop=True).shape)
def test_make_patch_stack_from_list(self):
w = 256
h = 512
......@@ -188,6 +190,11 @@ class TestPatchStackAccessor(unittest.TestCase):
self.assertEqual(acc.iat(0).shape, (h, 2 * w, c, nz))
self.assertEqual(acc.iat_yxcz(0).shape, (h, 2 * w, c, nz))
# test that initial patches are maintained
for i in range(0, acc.count):
self.assertEqual(patches[i].shape, acc.iat(i, crop=True).shape)
self.assertEqual(acc.shape[1:], acc.iat(i, crop=False).shape)
def test_pczyx(self):
w = 256
h = 512
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment