Skip to content
Snippets Groups Projects
Commit b54bfadd authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test coverage of patch stack object classification

parent f21346e9
No related branches found
No related tags found
No related merge requests found
......@@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask()
assert input_acc.chroma == 1
if not input_acc.chroma == 1:
raise InvalidInputImageError('Object classifier expects only monochrome patches')
if not input_acc.nz == 1:
raise InvalidInputImageError('Object classifier expects only 2d patches')
tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
......@@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
assert len(obmaps) == 1, 'ilastik generated more than one object map'
# for some reason ilastik scrambles these axes to Z(1)YX(1)
assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
yxz = np.moveaxis(
obmaps[0][:, 0, :, :, 0],
[1, 2, 0],
[0, 1, 2]
# for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
assert yxz.shape[0:2] == input_acc.hw
assert yxz.shape[2] == input_acc.nz
return PatchStack(data=yxz), {'success': True}
return PatchStack(data=pyxcz), {'success': True}
\ No newline at end of file
......@@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase):
def test_classify_patches(self):
raw_patches = self.roiset.get_raw_patches()
patch_masks = self.roiset.get_patch_masks()
res = self.object_classifier.infer(raw_patches, patch_masks)
self.assertEqual(0, 1)
res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
self.assertEqual(res_patches.count, self.roiset.count)
for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch
unique = np.unique(res_patches.iat(pi))
self.assertEqual(len(unique), 2)
self.assertEqual(unique[0], 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment