From 54026e2f6e356e1ec2ca649e717d282ed963abbb Mon Sep 17 00:00:00 2001
From: Constantin Pape <c.pape@gmx.net>
Date: Thu, 17 Oct 2019 16:10:24 +0200
Subject: [PATCH] Update cell seg evaluation

---
 analysis/validation/eval_cell_segmentation.py |  63 ++++++++++
 .../validation/eval_nucleus_segmentation.py   |   0
 scripts/segmentation/validation/eval_cells.py |   2 +-
 .../validation/refine_annotations.py          | 116 ++++++++++++++++++
 4 files changed, 180 insertions(+), 1 deletion(-)
 create mode 100644 analysis/validation/eval_cell_segmentation.py
 create mode 100644 analysis/validation/eval_nucleus_segmentation.py
 create mode 100644 scripts/segmentation/validation/refine_annotations.py

diff --git a/analysis/validation/eval_cell_segmentation.py b/analysis/validation/eval_cell_segmentation.py
new file mode 100644
index 0000000..b8f6470
--- /dev/null
+++ b/analysis/validation/eval_cell_segmentation.py
@@ -0,0 +1,63 @@
+# TODO shebang
+import argparse
+import os
+from scripts.segmentation.validation import eval_cells, get_ignore_seg_ids
+
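+# paths to the manual validation annotations and the baseline cell segmentations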
+ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5'
+BASELINES = '../../data/rawdata/evaluation/baseline_cell_segmentations.h5'
+
+
+def eval_seg(path, key, table):
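+    # evaluate a cell segmentation against the validation annotations,
+    # excluding the segment ids derived from the region table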
+    ignore_ids = get_ignore_seg_ids(table)
+    fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
+                             ignore_seg_ids=ignore_ids)
+    print("Evaluation yields:")
+    print("False merges:", fm)
+    print("False splits:", fs)
+    print("Total number of annotations:", tot)
+
+
+def eval_baselines():
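+    # evaluate the baseline segmentations (lmc / mc and their curated versions)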
+    names = ['lmc', 'mc', 'curated_lmc', 'curated_mc']
+    # TODO still need to compute region tables for the baselines
+    tables = ['',
+              '',
+              '',
+              '']
+    results = {}
+    for name, table in zip(names, tables):
+        ignore_ids = get_ignore_seg_ids(table)
+        # NOTE: assumes each baseline segmentation is stored in the BASELINES
+        # file under a key matching its name
+        fm, fs, tot = eval_cells(BASELINES, name, ANNOTATIONS,
+                                 ignore_seg_ids=ignore_ids)
+        results[name] = (fm, fs, tot)
+
+    for name in names:
+        fm, fs, tot = results[name]
+        print("Evaluation of", name, "yields:")
+        print("False merges:", fm)
+        print("False splits:", fs)
+        print("Total number of annotations:", tot)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", type=str, help="Path to segmentation that should be validated.")
+    parser.add_argument("table", type=str, help="Path to table with region/semantic assignments")
+    parser.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key")
+    parser.add_argument("--baselines", type=int, default=0,
+                        help="Whether to evaluate the baseline segmentations (overrides path)")
+    args = parser.parse_args()
+
+    baselines = bool(args.baselines)
+    if baselines:
+        eval_baselines()
+    else:
+        path = args.path
+        table = args.table
+        key = args.key
+        assert os.path.exists(path), path
+        eval_seg(path, key, table)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/analysis/validation/eval_nucleus_segmentation.py b/analysis/validation/eval_nucleus_segmentation.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/segmentation/validation/eval_cells.py b/scripts/segmentation/validation/eval_cells.py
index 908d407..91bd7d2 100644
--- a/scripts/segmentation/validation/eval_cells.py
+++ b/scripts/segmentation/validation/eval_cells.py
@@ -52,7 +52,7 @@ def to_scores(eval_res):
     n = float(eval_res['n_annotations'] - eval_res['n_unmatched'])
     n_splits = eval_res['n_splits']
     n_merges = eval_res['n_merged_annotations']
-    return n_merges / n, n_splits / n
+    return n_merges / n, n_splits / n, n
 
 
 def eval_cells(seg_path, seg_key,
diff --git a/scripts/segmentation/validation/refine_annotations.py b/scripts/segmentation/validation/refine_annotations.py
new file mode 100644
index 0000000..f58df90
--- /dev/null
+++ b/scripts/segmentation/validation/refine_annotations.py
@@ -0,0 +1,116 @@
+import os
+import json
+import vigra
+import numpy as np
+import napari
+from heimdall import view, to_source
+from elf.io import open_file
+
+from .eval_cells import get_bounding_box
+from .evaluate_annotations import evaluate_annotations
+
+
+def compute_masks(seg, labels, ignore_seg_ids):
+
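+    # compute connected components of the slice segmentation (0 stays background)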
+    seg_eval = vigra.analysis.labelImageWithBackground(seg)
+
+    if ignore_seg_ids is None:
+        this_ignore_ids = None
+    else:
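+        # map the ignore ids from the original segmentation to the relabeled one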
+        ignore_mask = np.isin(seg, ignore_seg_ids)
+        this_ignore_ids = np.unique(seg_eval[ignore_mask])
+
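+    # annotation labels 1 and 2 are treated as foreground, label 3 as background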
+    fg_annotations = np.isin(labels, [1, 2]).astype('uint32')
+    bg_annotations = labels == 3
+
+    min_radius = 16
+    _, masks = evaluate_annotations(seg_eval, fg_annotations, bg_annotations,
+                                    this_ignore_ids, min_radius=min_radius,
+                                    return_masks=True)
+    return masks['merges'], masks['splits']
+
+
+def refine(seg_path, seg_key, ignore_seg_ids,
+           orientation, slice_id,
+           project_folder,
+           annotation_path='/g/arendt/...',
+           raw_path='/g/arendt/...',
+           raw_key='t00000/s00/1/cells'):
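+    # open raw data, annotations, segmentation and the false merge / false split
+    # masks for the annotated slice in napari; press 's' in the viewer to save
+    # the current state to the project folder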
+
+    label_path = os.path.join(project_folder, 'labels.npy')
+    fm_path = os.path.join(project_folder, 'fm.npy')
+    fs_path = os.path.join(project_folder, 'fs.npy')
+    bb_path = os.path.join(project_folder, 'bounding_box.json')
+
+    labels, fm, fs = None, None, None
+    if os.path.exists(project_folder):
+        print("Load from existing project")
+        labels = np.load(label_path) if os.path.exists(label_path) else None
+        fm = np.load(fm_path) if os.path.exists(fm_path) else None
+        fs = np.load(fs_path) if os.path.exists(fs_path) else None
+
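+    # load the annotations and the bounding box of the annotated slice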
+    with open_file(annotation_path, 'r') as fval:
+        ds = fval[orientation][str(slice_id)]
+        bb = get_bounding_box(ds)
+        ds.n_threads = 8
+        if labels is None:
+            labels = ds[:]
+
+    starts = [b.start for b in bb]
+    stops = [b.stop for b in bb]
+
+    with open_file(seg_path, 'r') as f:
+        ds = f[seg_key]
+        ds.n_threads = 8
+        seg = ds[bb].squeeze().astype('uint32')
+
+    with open_file(raw_path, 'r') as f:
+        ds = f[raw_key]
+        ds.n_threads = 8
+        raw = ds[bb].squeeze()
+
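+    # compute the false merge / false split masks unless they were loaded from the project folder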
+    assert labels.shape == seg.shape
+    if fm is None:
+        assert fs is None
+        fm, fs = compute_masks(seg, labels, ignore_seg_ids)
+    else:
+        assert fs is not None
+
+    with napari.gui_qt():
+        viewer = view(to_source(raw, name='raw'), to_source(labels, name='labels'),
+                      to_source(seg, name='seg'), to_source(fm, name='merges'),
+                      to_source(fs, name='splits'), return_viewer=True)
+
+        @viewer.bind_key('s')
+        def save_labels(viewer):
+            print("Saving state ...")
+            layers = viewer.layers
+            os.makedirs(project_folder, exist_ok=True)
+
+            labels = layers['labels'].data
+            np.save(label_path, labels)
+
+            fm = layers['merges'].data
+            np.save(fm_path, fm)
+
+            fs = layers['splits'].data
+            np.save(fs_path, fs)
+
+            with open(bb_path, 'w') as f:
+                json.dump({'starts': starts, 'stops': stops}, f)
+            print("... done")
+
+
+def export_refined(project_folder, out_path, out_key):
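+    # export the refined annotation labels and their bounding box to the output dataset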
+
+    print("Export", project_folder, "to", out_path, out_key)
+    label_path = os.path.join(project_folder, 'labels.npy')
+    labels = np.load(label_path)
+
+    bb_path = os.path.join(project_folder, 'bounding_box.json')
+    with open(bb_path) as f:
+        bb = json.load(f)
+
+    with open_file(out_path) as out:
+        dso = out.create_dataset(out_key, data=labels, compression='gzip')
+        dso.attrs['starts'] = bb['starts']
+        dso.attrs['stops'] = bb['stops']
-- 
GitLab