Skip to content
Snippets Groups Projects
Commit ba81ccf9 authored by Constantin Pape's avatar Constantin Pape
Browse files

Update validation code

parent d5656cde
No related branches found
No related tags found
1 merge request!8Segmentation validation and correction
import h5py import h5py
from heimdall import view, to_source from heimdall import view, to_source
from scripts.segmentation.validation import eval_cells, eval_nuclei, get_ignore_seg_ids
def check_cell_evaluation(): def check_cell_evaluation():
from scripts.segmentation.validation.eval_cells import (eval_slice, from scripts.segmentation.validation.eval_cells import (eval_slice,
get_ignore_seg_ids,
get_bounding_box) get_bounding_box)
praw = '../../data' praw = '../../data/rawdata/sbem-6dpf-1-whole-raw.h5'
pseg = '../../data' pseg = '../../data/0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5'
pann = '../../data' pann = '../../data/rawdata/evaluation/validation_annotations.h5'
table_path = '../../data' table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv'
ignore_seg_ids = get_ignore_seg_ids(table_path) ignore_seg_ids = get_ignore_seg_ids(table_path)
with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann: with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann:
ds_seg = fseg['t00000/s00/0/cells'] ds_seg = fseg['t00000/s00/0/cells']
ds_ann = fann['xy/'] ds_ann = fann['xy/1000']
print("Run evaluation ...") print("Run evaluation ...")
res, masks = eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius=16, res, masks = eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius=16,
return_maksks=True) return_masks=True)
fm, fs = masks['false_merges'], masks['false_splits'] fm, fs = masks['merges'], masks['splits']
print() print()
print("Eval result") print("Eval result")
print(res) print(res)
...@@ -30,18 +30,38 @@ def check_cell_evaluation(): ...@@ -30,18 +30,38 @@ def check_cell_evaluation():
print("Load raw data ...") print("Load raw data ...")
bb = get_bounding_box(ds_ann) bb = get_bounding_box(ds_ann)
with h5py.File(praw, 'r') as f: with h5py.File(praw, 'r') as f:
raw = f['t00000/s00/1/cells'][bb] raw = f['t00000/s00/1/cells'][bb].squeeze()
print("Load seg data ...") print("Load seg data ...")
seg = ds_seg[bb].squeeze() seg = ds_seg[bb].squeeze().astype('uint32')
view(to_source(raw, name='raw'), to_source(seg, name='seg'), view(to_source(raw, name='raw'), to_source(seg, name='seg'),
to_source(fm, name='merges'), to_source(fs, name='splits')) to_source(fm, name='merges'), to_source(fs, name='splits'))
# def check_nucleus_evaluation(): def eval_all_cells():
# eval_nuclei() pseg = '../../data/0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5'
pann = '../../data/rawdata/evaluation/validation_annotations.h5'
table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv'
ignore_seg_ids = get_ignore_seg_ids(table_path)
res = eval_cells(pseg, 't00000/s00/0/cells', pann,
ignore_seg_ids=ignore_seg_ids)
print("Eval result:")
print(res)
# TODO
def check_nucleus_evaluation():
pass
# TODO
def eval_all_nulcei():
eval_nuclei()
if __name__ == '__main__': if __name__ == '__main__':
check_cell_evaluation() # check_cell_evaluation()
eval_all_cells()
from .eval_cells import eval_cells from .eval_cells import eval_cells, get_ignore_seg_ids
from .eval_nuclei import eval_nuclei from .eval_nuclei import eval_nuclei
...@@ -19,7 +19,7 @@ def eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius, ...@@ -19,7 +19,7 @@ def eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius,
bb = get_bounding_box(ds_ann) bb = get_bounding_box(ds_ann)
annotations = ds_ann[:] annotations = ds_ann[:]
seg = ds_seg[bb].squeeze() seg = ds_seg[bb].squeeze().astype('uint32')
assert annotations.shape == seg.shape assert annotations.shape == seg.shape
seg_eval = vigra.analysis.labelImageWithBackground(seg) seg_eval = vigra.analysis.labelImageWithBackground(seg)
...@@ -39,7 +39,7 @@ def eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius, ...@@ -39,7 +39,7 @@ def eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius,
def get_ignore_seg_ids(table_path, ignore_names=['cuticle', 'neuropil', 'yolk']): def get_ignore_seg_ids(table_path, ignore_names=['cuticle', 'neuropil', 'yolk']):
table = pd.read_csv(table_path) table = pd.read_csv(table_path, sep='\t')
ignore_seg_ids = [] ignore_seg_ids = []
for name in ignore_names: for name in ignore_names:
col = table[name].values.astype('uint8') col = table[name].values.astype('uint8')
...@@ -48,16 +48,25 @@ def get_ignore_seg_ids(table_path, ignore_names=['cuticle', 'neuropil', 'yolk']) ...@@ -48,16 +48,25 @@ def get_ignore_seg_ids(table_path, ignore_names=['cuticle', 'neuropil', 'yolk'])
return ignore_seg_ids return ignore_seg_ids
def to_scores(eval_res):
n = float(eval_res['n_annotations'] - eval_res['n_unmatched'])
n_splits = eval_res['n_splits']
n_merges = eval_res['n_merged_annotations']
return n_merges / n, n_splits / n
def eval_cells(seg_path, seg_key, def eval_cells(seg_path, seg_key,
annotation_path, annotation_key, annotation_path, annotation_key=None,
ignore_seg_ids=None, min_radius=16): ignore_seg_ids=None, min_radius=16):
""" Evaluate the cell segmentation. """ Evaluate the cell segmentation by computing
the percentage of falsely merged and split cell annotations
in manually annotated validation slices.
""" """
eval_res = {} eval_res = {}
with open_file(seg_path, 'r') as f_seg, open_file(annotation_path) as f_ann: with open_file(seg_path, 'r') as f_seg, open_file(annotation_path, 'r') as f_ann:
ds_seg = f_seg[seg_key] ds_seg = f_seg[seg_key]
g = f_ann[annotation_key] g = f_ann if annotation_key is None else f_ann[annotation_key]
def visit_annotation(name, node): def visit_annotation(name, node):
nonlocal eval_res nonlocal eval_res
...@@ -65,9 +74,11 @@ def eval_cells(seg_path, seg_key, ...@@ -65,9 +74,11 @@ def eval_cells(seg_path, seg_key,
print("Evaluating:", name) print("Evaluating:", name)
res = eval_slice(ds_seg, node, ignore_seg_ids, min_radius) res = eval_slice(ds_seg, node, ignore_seg_ids, min_radius)
eval_res = merge_evaluations(res, eval_res) eval_res = merge_evaluations(res, eval_res)
# for debugging
# print("current eval:", eval_res)
else: else:
print("Group:", name) print("Group:", name)
g.visititems(visit_annotation) g.visititems(visit_annotation)
return eval_res return to_scores(eval_res)
...@@ -65,7 +65,7 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations, ...@@ -65,7 +65,7 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations,
# unless we have overlap with a background annotation # unless we have overlap with a background annotation
# or are in the filter ids # or are in the filter ids
if this_labels.size == 0: if this_labels.size == 0:
if not has_bg_label and radii > min_radius: if not has_bg_label and radii[seg_id] > min_radius:
unmatched_ids.append(seg_id) unmatched_ids.append(seg_id)
# one label -> this seg-id seems to be well matched # one label -> this seg-id seems to be well matched
...@@ -80,17 +80,26 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations, ...@@ -80,17 +80,26 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations,
# increase the segment count # increase the segment count
n_segments += 1 n_segments += 1
# false splits = unmatched seg-ids and seg-ids corresponding to annotations that were matched # false splits = unmatched seg-ids and seg-ids corresponding to annotations
# more than once # that were matched more than once
# first, turn matched ids and labels into numpy arrays
matched_labels = list(matched_ids.values()) matched_labels = list(matched_ids.values())
matched_ids = np.array(list(matched_ids.keys()), dtype='uint32') matched_ids = np.array(list(matched_ids.keys()), dtype='uint32')
matched_labels, matched_counts = np.unique(matched_labels, return_counts=True)
# find the unique matched labels, their counts and the mapping to the original array
matched_labels, inv_mapping, matched_counts = np.unique(matched_labels, return_inverse=True, return_counts=True)
matched_counts = matched_counts[inv_mapping]
assert len(matched_counts) == len(matched_ids)
# combine unmatched ids and ids matched more than once
unmatched_ids = np.array(unmatched_ids, dtype='uint32')
false_split_ids = np.concatenate([unmatched_ids, matched_ids[matched_counts > 1]]) false_split_ids = np.concatenate([unmatched_ids, matched_ids[matched_counts > 1]])
# false merge annotations = overmatched ids # false merge annotations = overmatched ids
false_merge_ids = list(overmatched_ids.keys()) false_merge_ids = list(overmatched_ids.keys())
false_merge_labels = np.array([lab for overmatched in overmatched_ids false_merge_labels = np.array([lab for overmatched in overmatched_ids.values()
for lab in overmatched.values()], dtype='uint32') for lab in overmatched], dtype='uint32')
# find label ids that were not matched # find label ids that were not matched
all_matched = np.concatenate([matched_labels, false_merge_labels]) all_matched = np.concatenate([matched_labels, false_merge_labels])
...@@ -106,8 +115,10 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations, ...@@ -106,8 +115,10 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations,
'n_merged_ids': len(false_merge_ids), 'n_merged_ids': len(false_merge_ids),
'n_unmatched': len(unmatched_labels)} 'n_unmatched': len(unmatched_labels)}
ret = (metrics,) if not return_masks and not return_ids:
return metrics
ret = (metrics,)
if return_masks: if return_masks:
fs_mask = np.isin(seg, false_split_ids).astype('uint32') fs_mask = np.isin(seg, false_split_ids).astype('uint32')
fm_mask = np.isin(seg, false_merge_ids).astype('uint32') fm_mask = np.isin(seg, false_merge_ids).astype('uint32')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment