Skip to content
Snippets Groups Projects
Commit b54bfadd authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test coverage of patch stack object classification

parent f21346e9
No related branches found
No related tags found
No related merge requests found
...@@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): ...@@ -189,7 +189,10 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask() assert segmentation_acc.is_mask()
assert input_acc.chroma == 1 if not input_acc.chroma == 1:
raise InvalidInputImageError('Object classifier expects only monochrome patches')
if not input_acc.nz == 1:
raise InvalidInputImageError('Object classifier expects only 2d patches')
tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
...@@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): ...@@ -205,14 +208,12 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
assert len(obmaps) == 1, 'ilastik generated more than one object map' assert len(obmaps) == 1, 'ilastik generated more than one object map'
# for some reason ilastik scrambles these axes to Z(1)YX(1) # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1) assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
yxz = np.moveaxis( pyxcz = np.moveaxis(
obmaps[0][:, 0, :, :, 0], obmaps[0],
[1, 2, 0], [0, 1, 2, 3, 4],
[0, 1, 2] [0, 4, 1, 2, 3]
) )
assert yxz.shape[0:2] == input_acc.hw return PatchStack(data=pyxcz), {'success': True}
assert yxz.shape[2] == input_acc.nz \ No newline at end of file
return PatchStack(data=yxz), {'success': True}
...@@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase): ...@@ -292,5 +292,9 @@ class TestIlastikObjectClassification(unittest.TestCase):
def test_classify_patches(self): def test_classify_patches(self):
raw_patches = self.roiset.get_raw_patches() raw_patches = self.roiset.get_raw_patches()
patch_masks = self.roiset.get_patch_masks() patch_masks = self.roiset.get_patch_masks()
res = self.object_classifier.infer(raw_patches, patch_masks) res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
self.assertEqual(0, 1) self.assertEqual(res_patches.count, self.roiset.count)
for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch
unique = np.unique(res_patches.iat(pi))
self.assertEqual(len(unique), 2)
self.assertEqual(unique[0], 0)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment