From 0e93379d8c0b62df913a6825eaa9bd5b1835e2b9 Mon Sep 17 00:00:00 2001
From: Constantin Pape <c.pape@gmx.net>
Date: Thu, 17 Oct 2019 20:50:07 +0200
Subject: [PATCH] Finish (initial) implementation of segmentation correction
 and validation.
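
A rough sketch of the intended false-merge correction workflow (concrete paths
and keys are placeholders, not files in this repository):

    from scripts.segmentation.correction import rank_false_merges, CorrectionTool
    from scripts.segmentation.correction import export_node_labels, to_paintera_format

    # 1. rank false-merge candidates with the weight-quantile heuristic
    # 2. resolve the candidates interactively with the napari-based CorrectionTool
    # 3. write the corrected node labels back in paintera format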

---
 analysis/validation/check_validation.py       |  67 ---
 analysis/validation/eval_cell_segmentation.py |  63 ---
 .../validation/eval_nucleus_segmentation.py   |   0
 scripts/attributes/region_attributes.py       |   5 +-
 scripts/segmentation/correction/__init__.py   |   8 +
 .../correction/annotation_tool.py             | 213 +++++++++
 .../correction/assignment_diffs.py            |  54 +++
 .../correction/cillia_correction_tool.py      | 216 +++++++++
 .../correction/correct_false_merges.py        |  23 -
 .../correction/correction_tool.py             | 441 ++++++++++++++++++
 .../correction/export_node_labels.py          | 177 +++++++
 scripts/segmentation/correction/heuristics.py | 173 +++++++
 .../correction/mark_false_merges.py           |  20 -
 scripts/segmentation/correction/preprocess.py | 117 +++++
 .../segmentation/validation/eval_nuclei.py    |  74 ++-
 .../validation/evaluate_annotations.py        |   4 +-
 .../validation/refine_annotations.py          |  12 +-
 .../correction/correct_cell_segmentation.py   | 175 +++++++
 segmentation/correction/deprecated.py         |  47 ++
 segmentation/correction/fix_assignments.py    |  69 +++
 segmentation/correction/morphology_outlier.py |  53 +++
 segmentation/validation/check_validation.py   | 106 +++++
 .../validation/eval_cell_segmentation.py      |  94 ++++
 .../validation/eval_nucleus_segmentation.py   |  29 ++
 .../validation/proofread_annotations.py       |  29 ++
 25 files changed, 2082 insertions(+), 187 deletions(-)
 delete mode 100644 analysis/validation/check_validation.py
 delete mode 100644 analysis/validation/eval_cell_segmentation.py
 delete mode 100644 analysis/validation/eval_nucleus_segmentation.py
 create mode 100644 scripts/segmentation/correction/annotation_tool.py
 create mode 100644 scripts/segmentation/correction/assignment_diffs.py
 create mode 100644 scripts/segmentation/correction/cillia_correction_tool.py
 delete mode 100644 scripts/segmentation/correction/correct_false_merges.py
 create mode 100644 scripts/segmentation/correction/correction_tool.py
 create mode 100644 scripts/segmentation/correction/export_node_labels.py
 create mode 100644 scripts/segmentation/correction/heuristics.py
 delete mode 100644 scripts/segmentation/correction/mark_false_merges.py
 create mode 100644 scripts/segmentation/correction/preprocess.py
 create mode 100644 segmentation/correction/correct_cell_segmentation.py
 create mode 100644 segmentation/correction/deprecated.py
 create mode 100644 segmentation/correction/fix_assignments.py
 create mode 100644 segmentation/correction/morphology_outlier.py
 create mode 100644 segmentation/validation/check_validation.py
 create mode 100644 segmentation/validation/eval_cell_segmentation.py
 create mode 100644 segmentation/validation/eval_nucleus_segmentation.py
 create mode 100644 segmentation/validation/proofread_annotations.py

diff --git a/analysis/validation/check_validation.py b/analysis/validation/check_validation.py
deleted file mode 100644
index 3007aa6..0000000
--- a/analysis/validation/check_validation.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import h5py
-from heimdall import view, to_source
-from scripts.segmentation.validation import eval_cells, eval_nuclei, get_ignore_seg_ids
-
-
-def check_cell_evaluation():
-    from scripts.segmentation.validation.eval_cells import (eval_slice,
-                                                            get_bounding_box)
-
-    praw = '../../data/rawdata/sbem-6dpf-1-whole-raw.h5'
-    pseg = '../../data/0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5'
-    pann = '../../data/rawdata/evaluation/validation_annotations.h5'
-
-    table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv'
-    ignore_seg_ids = get_ignore_seg_ids(table_path)
-
-    with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann:
-        ds_seg = fseg['t00000/s00/0/cells']
-        ds_ann = fann['xy/1000']
-
-        print("Run evaluation ...")
-        res, masks = eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius=16,
-                                return_masks=True)
-        fm, fs = masks['merges'], masks['splits']
-        print()
-        print("Eval result")
-        print(res)
-        print()
-
-        print("Load raw data ...")
-        bb = get_bounding_box(ds_ann)
-        with h5py.File(praw, 'r') as f:
-            raw = f['t00000/s00/1/cells'][bb].squeeze()
-
-        print("Load seg data ...")
-        seg = ds_seg[bb].squeeze().astype('uint32')
-
-        view(to_source(raw, name='raw'), to_source(seg, name='seg'),
-             to_source(fm, name='merges'), to_source(fs, name='splits'))
-
-
-def eval_all_cells():
-    pseg = '../../data/0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5'
-    pann = '../../data/rawdata/evaluation/validation_annotations.h5'
-
-    table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv'
-    ignore_seg_ids = get_ignore_seg_ids(table_path)
-
-    res = eval_cells(pseg, 't00000/s00/0/cells', pann,
-                     ignore_seg_ids=ignore_seg_ids)
-    print("Eval result:")
-    print(res)
-
-
-# TODO
-def check_nucleus_evaluation():
-    pass
-
-
-# TODO
-def eval_all_nulcei():
-    eval_nuclei()
-
-
-if __name__ == '__main__':
-    # check_cell_evaluation()
-    eval_all_cells()
diff --git a/analysis/validation/eval_cell_segmentation.py b/analysis/validation/eval_cell_segmentation.py
deleted file mode 100644
index b8f6470..0000000
--- a/analysis/validation/eval_cell_segmentation.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# TODO shebang
-import argparse
-import os
-from scripts.segmentation.validation import eval_cells, get_ignore_seg_ids
-
-ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5'
-BASELINES = '../../data/rawdata/evaluation/baseline_cell_segmentations.h5'
-
-
-def eval_seg(path, key, table):
-    ignore_ids = get_ignore_seg_ids(table)
-    fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
-                             ignore_seg_ids=ignore_ids)
-    print("Evaluation yields:")
-    print("False merges:", fm)
-    print("False splits:", fs)
-    print("Total number of annotations:", tot)
-
-
-def eval_baselines():
-    names = ['lmc', 'mc', 'curated_lmc', 'curated_mc']
-    # TODO still need to compute region tables for the baselines
-    tables = ['',
-              '',
-              '',
-              '']
-    results = {}
-    for name, table in zip(names, tables):
-        ignore_ids = get_ignore_seg_ids(table)
-        fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
-                                 ignore_seg_ids=ignore_ids)
-        results[name] = (fm, fs, tot)
-
-    for name in names:
-        print("Evaluation of", name, "yields:")
-        print("False merges:", fm)
-        print("False splits:", fs)
-        print("Total number of annotations:", tot)
-
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("path", type=str, help="Path to segmentation that should be validated.")
-    parser.add_argument("table", type=str, help="Path to table with region/semantic assignments")
-    parse.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key")
-    parser.add_argument("--baselines", type=int, default=0,
-                        help="Whether to evaluate the baseline segmentations (overrides path)")
-    args = parser.parse_args()
-
-    baselines = bool(args.baselines)
-    if baselines:
-        eval_baselines()
-    else:
-        path = args.path
-        table = args.table
-        key = args.key
-        assert os.path.exists(path), path
-        eval_seg(path, key, table)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/analysis/validation/eval_nucleus_segmentation.py b/analysis/validation/eval_nucleus_segmentation.py
deleted file mode 100644
index e69de29..0000000
diff --git a/scripts/attributes/region_attributes.py b/scripts/attributes/region_attributes.py
index 950836b..ae9ea9f 100644
--- a/scripts/attributes/region_attributes.py
+++ b/scripts/attributes/region_attributes.py
@@ -63,9 +63,8 @@ def muscle_attributes(muscle_path, key_muscle,
 
 def region_attributes(seg_path, region_out,
                       image_folder, segmentation_folder,
-                      label_ids, tmp_folder, target, max_jobs):
-
-    key_seg = 't00000/s00/2/cells'
+                      label_ids, tmp_folder, target, max_jobs,
+                      key_seg='t00000/s00/2/cells'):
     key_tissue = 't00000/s00/0/cells'
 
     # 1.) compute the mapping to carved regions
diff --git a/scripts/segmentation/correction/__init__.py b/scripts/segmentation/correction/__init__.py
index e69de29..d83c810 100644
--- a/scripts/segmentation/correction/__init__.py
+++ b/scripts/segmentation/correction/__init__.py
@@ -0,0 +1,8 @@
+from .annotation_tool import AnnotationTool
+from .correction_tool import CorrectionTool
+from .cillia_correction_tool import CiliaCorrectionTool
+from .marker_tool import MarkerTool
+
+from .preprocess import preprocess
+from .export_node_labels import export_node_labels, to_paintera_format
+from .heuristics import rank_false_merges, get_ignore_ids
diff --git a/scripts/segmentation/correction/annotation_tool.py b/scripts/segmentation/correction/annotation_tool.py
new file mode 100644
index 0000000..a0cfa3b
--- /dev/null
+++ b/scripts/segmentation/correction/annotation_tool.py
@@ -0,0 +1,213 @@
+import os
+import json
+import queue
+import sys
+import threading
+
+import numpy as np
+import napari
+
+from heimdall import view, to_source
+from elf.io import open_file
+
+
+# Annotate segments
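+#
+# A rough usage sketch (all paths and keys below are placeholders):
+#   tool = AnnotationTool('./annotation-project', 'ids.json',
+#                         'morphology.h5', 'table', scale_factor=16,
+#                         raw_path='raw.h5', raw_key='t00000/s00/3/cells',
+#                         ws_path='ws.n5', ws_key='ws')
+#   tool()  # opens napari per segment; press [h] in the viewer for the keybindings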
+class AnnotationTool:
+    n_threads = 3
+    queue_len = 6
+
+    def __init__(self, project_folder, id_path,
+                 table_path, table_key, scale_factor,
+                 raw_path, raw_key,
+                 ws_path, ws_key,
+                 node_label_path=None,
+                 node_label_key=None):
+        #
+        self.project_folder = project_folder
+        os.makedirs(self.project_folder, exist_ok=True)
+        self.processed_ids_file = os.path.join(project_folder, 'processed_ids.json')
+        self.annotation_path = os.path.join(project_folder, 'annotations.json')
+
+        self.id_path = id_path
+
+        self.table_path = table_path
+        self.table_key = table_key
+        self.scale_factor = scale_factor
+
+        self.raw_path = raw_path
+        self.raw_key = raw_key
+
+        self.ws_path = ws_path
+        self.ws_key = ws_key
+
+        self.node_label_path = node_label_path
+        self.node_label_key = node_label_key
+
+        self.load_all_data()
+        self.init_queue_and_workers()
+        print("Initialization done")
+
+    def load_all_data(self):
+        self.ds_raw = open_file(self.raw_path)[self.raw_key]
+        self.ds_ws = open_file(self.ws_path)[self.ws_key]
+        self.shape = self.ds_raw.shape
+        assert self.ds_ws.shape == self.shape
+
+        if self.node_label_path is None:
+            self.node_labels = None
+        else:
+            with open_file(self.node_label_path, 'r') as f:
+                self.node_labels = f[self.node_label_key][:]
+
+        with open(self.id_path) as f:
+            self.ids = np.array(json.load(f))
+
+        if os.path.exists(self.processed_ids_file):
+            with open(self.processed_ids_file) as f:
+                self.processed_ids = json.load(f)
+        else:
+            self.processed_ids = []
+
+        already_processed = np.in1d(self.ids, self.processed_ids)
+        missing_ids = self.ids[~already_processed]
+
+        if os.path.exists(self.annotation_path):
+            with open(self.annotation_path) as f:
+                self.annotations = json.load(f)
+        else:
+            self.annotations = {}
+
+        self.next_queue = queue.Queue()
+        for mi in missing_ids:
+            self.next_queue.put_nowait(mi)
+
+        # morphology table entries
+        # id (1)
+        # size (1)
+        # com (3)
+        # bb-min (3)
+        # bb-max (3)
+        with open_file(self.table_path, 'r') as f:
+            table = f[self.table_key][:]
+        self.bb_starts = table[:, 5:8]
+        self.bb_stops = table[:, 8:]
+        self.bb_starts /= self.scale_factor
+        self.bb_stops /= self.scale_factor
+
+    def load_segment(self, seg_id):
+
+        # get bounding box of this segment
+        starts, stops = self.bb_starts[seg_id], self.bb_stops[seg_id]
+        halo = [2, 2, 2]
+        bb = tuple(slice(max(0, int(sta - ha)),
+                         min(sh, int(sto + ha)))
+                   for sta, sto, ha, sh in zip(starts, stops, halo, self.shape))
+
+        # load raw and watershed
+        raw = self.ds_raw[bb]
+        ws = self.ds_ws[bb]
+
+        # make the segment mask
+        if self.node_labels is None:
+            seg_mask = (ws == seg_id).astype('uint32')
+            ws = None
+        else:
+            node_ids = np.where(self.node_labels == seg_id)[0].astype('uint64')
+            seg_mask = np.isin(ws, node_ids)
+            ws[~seg_mask] = 0
+
+        return raw, ws, seg_mask.astype('uint32')
+
+    def worker_thread(self):
+        while not self.next_queue.empty():
+            seg_id = self.next_queue.get()
+            print("Loading seg", seg_id)
+            qitem = self.load_segment(seg_id)
+            self.queue.put((seg_id, qitem))
+
+    def init_queue_and_workers(self):
+        self.queue = queue.Queue(maxsize=self.queue_len)
+        for i in range(self.n_threads):
+            t = threading.Thread(name='worker-%i' % i, target=self.worker_thread)
+            t.setDaemon(True)
+            t.start()
+
+    def correct_segment(self, seg_id, qitem):
+        print("Processing segment:", seg_id)
+        raw, ws, seg = qitem
+        annotation = None
+
+        with napari.gui_qt():
+            if ws is None:
+                viewer = view(to_source(raw, name='raw'),
+                              to_source(seg, name='seg'),
+                              return_viewer=True)
+            else:
+                viewer = view(to_source(raw, name='raw'),
+                              to_source(ws, name='ws'),
+                              to_source(seg, name='seg'),
+                              return_viewer=True)
+
+            @viewer.bind_key('h')
+            def print_help(viewer):
+                print("The following annotations are available:")
+                print("[r] - needs to be revisited")
+                print("[i] - incomplete, needs to be merged with other id")
+                print("[m] - merge with background")
+                print("[y] - confirm segment")
+                print("[c] - enter custom annotation")
+                print("Other options:")
+                print("[q] - quit")
+
+            @viewer.bind_key('r')
+            def revisit(viewer):
+                nonlocal annotation
+                print("Setting annotation to revisit")
+                annotation = 'revisit'
+
+            @viewer.bind_key('m')
+            def merge(viewer):
+                nonlocal annotation
+                print("Setting annotation to merge")
+                annotation = 'merge'
+
+            @viewer.bind_key('i')
+            def incomplete(viewer):
+                nonlocal annotation
+                print("Setting annotation to incomplete")
+                annotation = 'incomplete'
+
+            @viewer.bind_key('y')
+            def confirm(viewer):
+                nonlocal annotation
+                print("Setting annotation to confirm")
+                annotation = 'confirm'
+
+            @viewer.bind_key('c')
+            def custom(viewer):
+                nonlocal annotation
+                annotation = input("Enter custom annotation")
+
+            @viewer.bind_key('q')
+            def quit(viewer):
+                self.save_annotation(seg_id, annotation)
+                sys.exit(0)
+
+        self.save_annotation(seg_id, annotation)
+
+    def save_annotation(self, seg_id, annotation):
+        self.processed_ids.append(int(seg_id))
+        with open(self.processed_ids_file, 'w') as f:
+            json.dump(self.processed_ids, f)
+        if annotation is None:
+            return
+        self.annotations[int(seg_id)] = annotation
+        with open(self.annotation_path, 'w') as f:
+            json.dump(self.annotations, f)
+
+    def __call__(self):
+        left_to_process = len(self.ids) - len(self.processed_ids)
+        while left_to_process > 0:
+            seg_id, qitem = self.queue.get()
+            self.correct_segment(seg_id, qitem)
+            left_to_process = len(self.ids) - len(self.processed_ids)
diff --git a/scripts/segmentation/correction/assignment_diffs.py b/scripts/segmentation/correction/assignment_diffs.py
new file mode 100644
index 0000000..c85b046
--- /dev/null
+++ b/scripts/segmentation/correction/assignment_diffs.py
@@ -0,0 +1,54 @@
+import json
+import os
+
+import numpy as np
+import luigi
+import nifty.ground_truth as ngt
+from cluster_tools.node_labels import NodeLabelWorkflow
+
+# TODO some of this is general purpose and should go to elf
+
+
+def node_labels(ws_path, ws_key, input_path, input_key,
+                output_path, output_key, prefix,
+                tmp_folder, target='slurm', max_jobs=250):
+    task = NodeLabelWorkflow
+
+    configs = task.get_config()
+    config_folder = os.path.join(tmp_folder, 'configs')
+    os.makedirs(config_folder, exist_ok=True)
+
+    conf = configs['global']
+    shebang = '/g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python3.7'
+    conf['shebang'] = shebang
+    conf['block_shape'] = [16, 256, 256]
+
+    with open(os.path.join(config_folder, 'global.config'), 'w') as f:
+        json.dump(conf, f)
+
+    conf = configs['block_node_labels']
+    conf.update({'mem_limit': 4, 'time_limit': 120})
+    with open(os.path.join(config_folder, 'block_node_labels.config'), 'w') as f:
+        json.dump(conf, f)
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_folder,
+             max_jobs=max_jobs, target=target,
+             ws_path=ws_path, ws_key=ws_key,
+             input_path=input_path, input_key=input_key,
+             output_path=output_path, output_key=output_key,
+             prefix=prefix)
+    luigi.build([t], local_scheduler=True)
+
+
+# we only look at objects in the reference assignment that are SPLIT in
+# the new assignment. For the other direction, just reverse reference and new
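+# Minimal usage sketch (hypothetical label arrays of equal shape):
+#   splits = assignment_diff_splits(ref_labels, new_labels)  # ref objects split in new
+#   merges = assignment_diff_splits(new_labels, ref_labels)  # ref objects merged in new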
+def assignment_diff_splits(reference_assignment, new_assignment):
+    assert reference_assignment.shape == new_assignment.shape
+    reference_ids = np.unique(reference_assignment)
+
+    print("Computing assignment diff ...")
+    ovlp_comp = ngt.overlap(reference_assignment, new_assignment)
+    split_ids = [ovlp_comp.overlapArrays(ref_id)[0] for ref_id in reference_ids]
+    split_ids = {int(ref_id): len(ovlps) for ref_id, ovlps in zip(reference_ids, split_ids)
+                 if len(ovlps) > 1}
+    return split_ids
diff --git a/scripts/segmentation/correction/cillia_correction_tool.py b/scripts/segmentation/correction/cillia_correction_tool.py
new file mode 100644
index 0000000..7c0bd7c
--- /dev/null
+++ b/scripts/segmentation/correction/cillia_correction_tool.py
@@ -0,0 +1,216 @@
+# initially:
+# go over all cilia, load the volume and highlight the cell the cilium was mapped to if applicable
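+#
+# Usage sketch (version folder and mapping table are placeholders):
+#   tool = CiliaCorrectionTool('./cilia-project', 'data/0.6.0', scale=3,
+#                              cilia_cell_table='cilia_cell_mapping.csv')
+#   tool()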
+
+import json
+import os
+import queue
+import sys
+import threading
+
+import numpy as np
+import pandas as pd
+import napari
+
+from heimdall import view, to_source
+from elf.io import open_file
+from scripts.files.xml_utils import get_h5_path_from_xml
+
+
+def xml_to_h5_path(xml_path):
+    path = get_h5_path_from_xml(xml_path, return_absolute_path=True)
+    return path
+
+
+class CiliaCorrectionTool:
+    n_threads = 1
+    queue_len = 2
+
+    def __init__(self, project_folder,
+                 version_folder, scale, cilia_cell_table):
+        self.project_folder = project_folder
+        os.makedirs(self.project_folder, exist_ok=True)
+        self.processed_id_file = os.path.join(self.project_folder, 'processed_ids.json')
+
+        assert os.path.exists(version_folder)
+
+        raw_path = os.path.join(version_folder, 'images', 'sbem-6dpf-1-whole-raw.xml')
+        self.raw_path = xml_to_h5_path(raw_path)
+
+        cilia_seg_path = os.path.join(version_folder, 'segmentations',
+                                      'sbem-6dpf-1-whole-segmented-cilia-labels.xml')
+        self.cilia_seg_path = xml_to_h5_path(cilia_seg_path)
+
+        cell_seg_path = os.path.join(version_folder, 'segmentations',
+                                     'sbem-6dpf-1-whole-segmented-cells-labels.xml')
+        self.cell_seg_path = xml_to_h5_path(cell_seg_path)
+
+        self.cilia_table_path = os.path.join(version_folder, 'tables',
+                                             'sbem-6dpf-1-whole-segmented-cilia-labels', 'default.csv')
+        self.cilia_cell_table = pd.read_csv(cilia_cell_table, sep='\t')
+
+        self.scale = scale
+        self.init_data()
+        self.init_queue_and_workers()
+
+    def init_data(self):
+        # init the bounding boxes
+        table = pd.read_csv(self.cilia_table_path, sep='\t')
+        bb_start = table[['bb_min_z', 'bb_min_y', 'bb_min_x']].values.astype('float32')
+        bb_start[np.isinf(bb_start)] = 0
+        bb_stop = table[['bb_max_z', 'bb_max_y', 'bb_max_x']].values.astype('float32')
+
+        resolution = [.025, .01, .01]
+        scale_factor = [2 ** max(0, self.scale - 1)] + 2 * [2 ** self.scale]
+        resolution = [res * sf for res, sf in zip(resolution, scale_factor)]
+
+        halo = [4, 32, 32]
+        self.bbs = [tuple(slice(max(0, int(sta / res - ha)), int(sto / res + ha))
+                          for sta, sto, res, ha in zip(start, stop, resolution, halo))
+                    for start, stop in zip(bb_start, bb_stop)]
+
+        # mapping from cilia ids to seg_ids
+        cil_map_ids = self.cilia_cell_table['cilia_id'].values
+        cell_map_ids = self.cilia_cell_table['cell_id'].values
+        self.id_mapping = {cil_id: cell_id for cil_id, cell_id in zip(cil_map_ids, cell_map_ids)}
+
+        # init the relevant ids
+        self.cilia_ids = np.arange(len(self.bbs))
+        if os.path.exists(self.processed_id_file):
+            with open(self.processed_id_file) as f:
+                self.processed_id_map = json.load(f)
+        else:
+            self.processed_id_map = {}
+        self.processed_id_map = {int(k): v for k, v in self.processed_id_map.items()}
+        self.processed_ids = list(self.processed_id_map.keys())
+
+        already_processed = np.in1d(self.cilia_ids, self.processed_ids)
+        missing_ids = self.cilia_ids[~already_processed]
+
+        # fill the queue
+        self.next_queue = queue.Queue()
+        for mi in missing_ids:
+            # if mi in (0, 1):
+            #     continue
+            self.next_queue.put_nowait(mi)
+
+    def worker_thread(self):
+        while not self.next_queue.empty():
+            seg_id = self.next_queue.get()
+            print("Loading seg", seg_id)
+            qitem = self.load_data(seg_id)
+            self.queue.put((seg_id, qitem))
+
+    def init_queue_and_workers(self):
+        self.queue = queue.Queue(maxsize=self.queue_len)
+        for i in range(self.n_threads):
+            t = threading.Thread(name='worker-%i' % i, target=self.worker_thread)
+            # t.setDaemon(True)
+            t.start()
+        save_folder = os.path.join(self.project_folder, 'results')
+        os.makedirs(save_folder, exist_ok=True)
+
+    def load_data(self, cil_id):
+        if cil_id in (0, 1):
+            return None
+        cell_seg_key = 't00000/s00/%i/cells' % (self.scale - 1,)
+        cil_seg_key = 't00000/s00/%i/cells' % (self.scale + 1,)
+        raw_key = 't00000/s00/%i/cells' % self.scale
+
+        bb = self.bbs[cil_id]
+        with open_file(self.raw_path, 'r') as f:
+            ds = f[raw_key]
+            raw = ds[bb]
+
+        with open_file(self.cilia_seg_path, 'r') as f:
+            ds = f[cil_seg_key]
+            cil_seg = ds[bb].astype('uint32')
+        cil_mask = cil_seg == cil_id
+        cil_mask = 2 * cil_mask.astype('uint32')
+
+        cell_id = self.id_mapping[cil_id]
+        if cell_id == 0 or np.isnan(cell_id):
+            cell_seg = None
+        else:
+
+            with open_file(self.cell_seg_path, 'r') as f:
+                ds = f[cell_seg_key]
+                cell_seg = ds[bb].astype('uint32')
+                cell_seg = (cell_seg == cell_id).astype('uint32')
+
+        return raw, cil_seg, cil_mask, cell_seg
+
+    def __call__(self):
+        left_to_process = len(self.cilia_ids) - len(self.processed_ids)
+        print("Left to process:", left_to_process)
+        while left_to_process > 0:
+            seg_id, qitem = self.queue.get()
+            self.correct_segment(seg_id, qitem)
+            left_to_process = len(self.cilia_ids) - len(self.processed_ids)
+
+    def correct_segment(self, seg_id, qitem):
+        if qitem is None:
+            return
+
+        print("Processing cilia:", seg_id)
+        raw, cil_seg, cil_mask, cell_seg = qitem
+
+        with napari.gui_qt():
+            if cell_seg is None:
+                viewer = view(to_source(raw, name='raw'), to_source(cil_seg, name='cilia-segmentation'),
+                              to_source(cil_mask, name='cilia-mask'), return_viewer=True)
+            else:
+                viewer = view(to_source(raw, name='raw'), to_source(cil_seg, name='cilia-segmentation'),
+                              to_source(cil_mask, name='cilia-mask'), to_source(cell_seg, name='cell-segmentation'),
+                              return_viewer=True)
+
+            @viewer.bind_key('c')
+            def confirm(viewer):
+                print("Confirming the current id", seg_id, "as correct")
+                self.processed_id_map[int(seg_id)] = 'correct'
+
+            @viewer.bind_key('b')
+            def background(viewer):
+                print("Marking the current id", seg_id, "as background")
+                self.processed_id_map[int(seg_id)] = 'background'
+
+            @viewer.bind_key('m')
+            def merge(viewer):
+                print("Merging the current id", seg_id, "with other cilia")
+                valid_input = False
+                while not valid_input:
+                    merge_id = input("Please enter the merge id:")
+                    try:
+                        merge_id = int(merge_id)
+                        valid_input = True
+                    except ValueError:
+                        valid_input = False
+                        print("Invalid input", merge_id, "- please enter an integer id and try again")
+                self.processed_id_map[int(seg_id)] = merge_id
+
+            @viewer.bind_key('r')
+            def revisit(viewer):
+                print("Marking the current id", seg_id, "to be revisited because something is off")
+                self.processed_id_map[int(seg_id)] = 'revisit'
+
+            @viewer.bind_key('h')
+            def print_help(viewer):
+                print("[c] - confirm cilium as correct")
+                print("[b] - mark cilium as background")
+                print("[m] - merge cilium with another cilia id")
+                print("[r] - revisit this cilium")
+                print("[q] - quit")
+
+            # save progress and sys.exit
+            @viewer.bind_key('q')
+            def quit(viewer):
+                print("Quit correction tool")
+                self.save_result(seg_id)
+                sys.exit(0)
+
+        # save the results for this segment
+        self.save_result(seg_id)
+
+    def save_result(self, seg_id):
+        self.processed_ids.append(seg_id)
+        with open(self.processed_id_file, 'w') as f:
+            json.dump(self.processed_id_map, f)
diff --git a/scripts/segmentation/correction/correct_false_merges.py b/scripts/segmentation/correction/correct_false_merges.py
deleted file mode 100644
index 59e9e35..0000000
--- a/scripts/segmentation/correction/correct_false_merges.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-import napari
-from elf.io import open_file
-
-
-# TODO
-# correct false merges via lifted multicut (or watershed)
-# load raw data, watershed from bounding box for correction id
-# load sub-graph for nodes corresponding to this segment
-# (this takes long, so preload 5 or so via daemon process)
-# add layer for seeds, resolve the segment via lmc (or watershed)
-# once happy, store the new ids and move on to the next
-class CorrectFalseMerges:
-    def __init__(self, project_folder,
-                 correct_id_path=None, table_path=None,
-                 raw_path=None, raw_key=None,
-                 ws_path=None, ws_key=None,
-                 node_label_path=None, node_label_key=None,
-                 problem_path=None, graph_key=None, cost_key=None):
-        pass
-
-    def __call__(self):
-        pass
diff --git a/scripts/segmentation/correction/correction_tool.py b/scripts/segmentation/correction/correction_tool.py
new file mode 100644
index 0000000..9f36a84
--- /dev/null
+++ b/scripts/segmentation/correction/correction_tool.py
@@ -0,0 +1,441 @@
+import os
+import json
+import queue
+import sys
+import threading
+
+import numpy as np
+import napari
+import nifty
+import nifty.distributed as ndist
+import nifty.tools as nt
+import vigra
+
+from heimdall import view, to_source
+from elf.io import open_file
+
+
+# correct false merges via lifted multicut (or watershed)
+# load raw data, watershed from bounding box for correction id
+# load sub-graph for nodes corresponding to this segment
+# (this takes long, so preload 5 or so via daemon process)
+# add layer for seeds, resolve the segment via lmc (or watershed)
+# once happy, store the new ids and move on to the next
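+#
+# Usage sketch (all paths and keys are placeholders; the configuration is cached in
+# the project folder, so later sessions only need the project folder argument):
+#   tool = CorrectionTool('./fm-project',
+#                         false_merge_id_path='fm_ids.json',
+#                         table_path='morphology.h5', table_key='table', scale_factor=16,
+#                         raw_path='raw.h5', raw_key='t00000/s00/3/cells',
+#                         seg_path='seg.h5', seg_key='t00000/s00/2/cells',
+#                         ws_path='ws.n5', ws_key='ws',
+#                         node_label_path='labels.n5', node_label_key='node_labels',
+#                         problem_path='problem.n5', graph_key='s0/graph', feat_key='features')
+#   tool()  # iterates over the false-merge candidates and opens napari for each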
+class CorrectionTool:
+    n_threads = 3
+    queue_len = 3
+
+    def __init__(self, project_folder,
+                 false_merge_id_path=None,
+                 table_path=None, table_key=None, scale_factor=None,
+                 raw_path=None, raw_key=None,
+                 seg_path=None, seg_key=None,
+                 ws_path=None, ws_key=None,
+                 node_label_path=None, node_label_key=None,
+                 problem_path=None, graph_key=None, feat_key=None,
+                 load_lazy=False):
+        #
+        self.project_folder = project_folder
+        self.config_file = os.path.join(project_folder, 'correct_false_merges_config.json')
+        self.processed_ids_file = os.path.join(project_folder, 'processed_ids.json')
+        self.annotation_path = os.path.join(project_folder, 'annotations.json')
+
+        if os.path.exists(self.config_file):
+            self.read_config()
+            serialize_config = False
+        else:
+            assert false_merge_id_path is not None
+            self.false_merge_id_path = false_merge_id_path
+
+            assert table_path is not None
+            self.table_path = table_path
+            assert table_key is not None
+            self.table_key = table_key
+            assert scale_factor is not None
+            self.scale_factor = scale_factor
+
+            assert raw_path is not None
+            self.raw_path = raw_path
+            assert raw_key is not None
+            self.raw_key = raw_key
+
+            assert seg_path is not None
+            self.seg_path = seg_path
+            assert seg_key is not None
+            self.seg_key = seg_key
+
+            assert ws_path is not None
+            self.ws_path = ws_path
+            assert ws_key is not None
+            self.ws_key = ws_key
+
+            assert node_label_path is not None
+            self.node_label_path = node_label_path
+            assert node_label_key is not None
+            self.node_label_key = node_label_key
+
+            assert problem_path is not None
+            self.problem_path = problem_path
+            assert graph_key is not None
+            self.graph_key = graph_key
+            assert feat_key is not None
+            self.feat_key = feat_key
+
+            serialize_config = True
+
+        # TODO implement lazy loading
+        self.load_lazy = load_lazy
+        print("Loading graph, weights and node labels ...")
+        self.load_all_data()
+        print("... done")
+
+        print("Initializing queues ...")
+        self.init_queue_and_workers()
+        print("... done")
+
+        if serialize_config:
+            self.write_config()
+        print("Initialization done")
+
+    def read_config(self):
+        with open(self.config_file) as f:
+            conf = json.load(f)
+
+        self.false_merge_id_path = conf['false_merge_id_path']
+        self.table_path = conf['table_path']
+        self.table_key = conf['table_key']
+        self.scale_factor = conf['scale_factor']
+
+        self.raw_path, self.raw_key = conf['raw_path'], conf['raw_key']
+        self.seg_path, self.seg_key = conf['seg_path'], conf['seg_key']
+        self.ws_path, self.ws_key = conf['ws_path'], conf['ws_key']
+        self.node_label_path, self.node_label_key = conf['node_label_path'], conf['node_label_key']
+
+        self.problem_path = conf['problem_path']
+        self.graph_key = conf['graph_key']
+        self.feat_key = conf['feat_key']
+
+    def write_config(self):
+        os.makedirs(self.project_folder, exist_ok=True)
+        conf = {'false_merge_id_path': self.false_merge_id_path,
+                'table_path': self.table_path, 'table_key': self.table_key,
+                'scale_factor': self.scale_factor,
+                'raw_path': self.raw_path, 'raw_key': self.raw_key,
+                'seg_path': self.seg_path, 'seg_key': self.seg_key,
+                'ws_path': self.ws_path, 'ws_key': self.ws_key,
+                'node_label_path': self.node_label_path, 'node_label_key': self.node_label_key,
+                'problem_path': self.problem_path, 'graph_key': self.graph_key,
+                'feat_key': self.feat_key}
+        with open(self.config_file, 'w') as f:
+            json.dump(conf, f)
+
+    def load_all_data(self):
+        self.ds_raw = open_file(self.raw_path)[self.raw_key]
+        self.ds_seg = open_file(self.seg_path)[self.seg_key]
+        self.ds_ws = open_file(self.ws_path)[self.ws_key]
+        self.shape = self.ds_raw.shape
+        assert self.ds_ws.shape == self.shape
+
+        with open_file(self.node_label_path, 'r') as f:
+            self.node_labels = f[self.node_label_key][:]
+
+        with open(self.false_merge_id_path) as f:
+            self.false_merge_ids = np.array(json.load(f))
+
+        if os.path.exists(self.processed_ids_file):
+            with open(self.processed_ids_file) as f:
+                self.processed_ids = json.load(f)
+        else:
+            self.processed_ids = []
+
+        already_processed = np.in1d(self.false_merge_ids, self.processed_ids)
+        missing_ids = self.false_merge_ids[~already_processed]
+
+        if os.path.exists(self.annotation_path):
+            with open(self.annotation_path) as f:
+                self.annotations = json.load(f)
+        else:
+            self.annotations = {}
+
+        self.next_queue = queue.Queue()
+        for mi in missing_ids:
+            self.next_queue.put_nowait(mi)
+
+        self.graph = ndist.Graph(self.problem_path, self.graph_key, self.n_threads)
+        self.uv_ids = self.graph.uvIds()
+        with open_file(self.problem_path, 'r') as f:
+            ds = f[self.feat_key]
+            ds.n_threads = self.n_threads
+
+            self.probs = ds[:, 0]
+
+        # morphology table entries
+        # id (1)
+        # size (1)
+        # com (3)
+        # bb-min (3)
+        # bb-max (3)
+        with open_file(self.table_path, 'r') as f:
+            table = f[self.table_key][:]
+        self.bb_starts = table[:, 5:8]
+        self.bb_stops = table[:, 8:]
+        self.bb_starts /= self.scale_factor
+        self.bb_stops /= self.scale_factor
+
+    def load_subgraph(self, node_ids):
+        inner_edges, _ = self.graph.extractSubgraphFromNodes(node_ids, allowInvalidNodes=True)
+        nodes_relabeled, max_id, mapping = vigra.analysis.relabelConsecutive(node_ids,
+                                                                             start_label=0,
+                                                                             keep_zeros=False)
+        uv_ids = self.uv_ids[inner_edges]
+        uv_ids = nt.takeDict(mapping, uv_ids)
+        n_nodes = max_id + 1
+        graph = nifty.graph.undirectedGraph(n_nodes)
+        graph.insertEdges(uv_ids)
+
+        probs = self.probs[inner_edges]
+        assert len(probs) == graph.numberOfEdges
+        return graph, probs, mapping
+
+    def load_segment(self, seg_id):
+
+        # get bounding box of this segment
+        starts, stops = self.bb_starts[seg_id], self.bb_stops[seg_id]
+        halo = [2, 2, 2]
+        bb = tuple(slice(max(0, int(sta - ha)),
+                         min(sh, int(sto + ha)))
+                   for sta, sto, ha, sh in zip(starts, stops, halo, self.shape))
+
+        # extract the sub-graph
+        node_ids = np.where(self.node_labels == seg_id)[0].astype('uint64')
+        graph, probs, mapping = self.load_subgraph(node_ids)
+        # graph, probs, mapping = None, None, None
+
+        # load raw and watershed
+        raw = self.ds_raw[bb]
+        try:
+            seg = self.ds_seg[bb]
+        except RuntimeError:
+            seg = None
+        ws = self.ds_ws[bb]
+
+        # make the segment mask
+        ws[~np.isin(ws, node_ids)] = 0
+        if seg is None:
+            seg_mask = ws > 0
+        else:
+            seg_mask = seg == seg_id
+
+        return raw, ws, seg_mask.astype('uint32'), graph, probs, mapping
+
+    def worker_thread(self):
+        while not self.next_queue.empty():
+            seg_id = self.next_queue.get()
+            print("Loading seg", seg_id)
+            qitem = self.load_segment(seg_id)
+            self.queue.put((seg_id, qitem))
+
+    def init_queue_and_workers(self):
+        self.queue = queue.Queue(maxsize=self.queue_len)
+        for i in range(self.n_threads):
+            t = threading.Thread(name='worker-%i' % i, target=self.worker_thread)
+            t.setDaemon(True)
+            t.start()
+        save_folder = os.path.join(self.project_folder, 'results')
+        os.makedirs(save_folder, exist_ok=True)
+
+    @staticmethod
+    def graph_watershed(graph, probs, ws, seed_points, mapping):
+        seed_ids = np.unique(seed_points)[1:]
+        if len(seed_ids) == 0:
+            return None
+
+        seeds = np.zeros(graph.numberOfNodes, dtype='uint64')
+        for seed_id in seed_ids:
+            mask = seed_points == seed_id
+            seed_nodes = np.unique(ws[mask])
+            if seed_nodes[0] == 0:
+                seed_nodes = seed_nodes[1:]
+            seed_nodes = nt.takeDict(mapping, seed_nodes)
+            seeds[seed_nodes] = seed_id
+
+        node_labels = nifty.graph.edgeWeightedWatershedsSegmentation(graph, seeds, probs)
+        return node_labels
+
+    def correct_segment(self, seg_id, qitem):
+        print("Processing segment:", seg_id)
+        raw, ws, seg, graph, probs, mapping = qitem
+        ws_ids = np.unique(ws)[1:]
+        node_labels = np.ones(graph.numberOfNodes, dtype='uint64')
+
+        seeds = np.zeros_like(ws, dtype='uint32')
+        skip_this_segment = False
+
+        with napari.gui_qt():
+            viewer = view(to_source(raw, name='raw'), to_source(ws, name='ws'),
+                          to_source(seg, name='seg'), to_source(seeds, name='seeds'),
+                          return_viewer=True)
+
+            @viewer.bind_key('h')
+            def print_help(viewer):
+                print("[w] - run watershed from seeds")
+                print("[s] - skip the current segment (if it's not a merge)")
+                print("[c] - clear seeds")
+                print("[d] - save current data for debugging")
+                print("[a] - annotate")
+                print("[q] - quit")
+
+            @viewer.bind_key('a')
+            def annotate(viewer):
+                msg = ("Annotating seg id; choose one of [c] (correct), [s] (revisit at lower scale), "
+                       "[m] (merge to background). Otherwise, you can add a custom annotation. ")
+                inp = input(msg)
+                if inp == 'c':
+                    annotation = 'correct'
+                elif inp == 's':
+                    annotation = 'revisit'
+                elif inp == 'm':
+                    annotation = 'merge'
+                else:
+                    annotation = input("Enter custom annotation.")
+                self.annotations[int(seg_id)] = annotation
+                with open(self.annotation_path, 'w') as f:
+                    json.dump(self.annotations, f)
+
+            @viewer.bind_key('s')
+            def skip(viewer):
+                print("Skipping the current segment")
+                nonlocal skip_this_segment
+                skip_this_segment = True
+                # TODO quit viewer
+
+            @viewer.bind_key('d')
+            def debug(viewer):
+                print("Saving debug data ...")
+                debug_folder = os.path.join(self.project_folder, 'debug')
+                os.makedirs(debug_folder, exist_ok=True)
+
+                layers = viewer.layers
+                seed_points = layers['seeds'].data
+                with open_file(os.path.join(debug_folder, 'data.n5')) as f:
+                    f.create_dataset('raw', data=raw, compression='gzip')
+                    f.create_dataset('ws', data=ws, compression='gzip')
+                    f.create_dataset('seeds', data=seed_points, compression='gzip')
+
+                with open(os.path.join(debug_folder, 'mapping.json'), 'w') as f:
+                    json.dump(mapping, f)
+                np.save(os.path.join(debug_folder, 'graph.npy'), graph.uvIds())
+                np.save(os.path.join(debug_folder, 'probs.npy'), probs)
+                print("... done")
+
+            @viewer.bind_key('w')
+            def watershed(viewer):
+                nonlocal node_labels, seeds
+
+                print("Run watershed from seed layer ...")
+                layers = viewer.layers
+                ws = layers['ws'].data
+                seeds = layers['seeds'].data
+
+                new_node_labels = self.graph_watershed(graph, probs, ws, seeds, mapping)
+                if new_node_labels is None:
+                    print("Did not find any seeds, doing nothing")
+                    return
+                else:
+                    node_labels = new_node_labels
+
+                label_dict = {wsid: node_labels[mapping[wsid]] for wsid in ws_ids}
+                label_dict[0] = 0
+                seg = nt.takeDict(label_dict, ws)
+
+                layers['seg'].data = seg
+                print("... done")
+
+            @viewer.bind_key('c')
+            def clear(viewer):
+                nonlocal node_labels, seeds
+                print("Clear seeds ...")
+                confirm = input("Do you really want to clear the seeds? y / [n]")
+                if confirm != 'y':
+                    return
+                node_labels = np.ones(graph.numberOfNodes, dtype='uint64')
+                seeds = np.zeros_like(ws, dtype='uint32')
+                viewer.layers['seeds'].data = seeds
+                seg = (ws > 0).astype('uint32')
+                viewer.layers['seg'].data = seg
+                print("... done")
+
+            # save progress and sys.exit
+            @viewer.bind_key('q')
+            def quit(viewer):
+                print("Quit correction tool")
+                self.save_segment_result(seg_id, ws, seeds,
+                                         node_labels, mapping, skip_this_segment)
+                sys.exit(0)
+
+        # save the results for this segment
+        self.save_segment_result(seg_id, ws, seeds, node_labels, mapping, skip_this_segment)
+
+    def save_segment_result(self, seg_id, ws, seeds, node_labels, mapping, skip):
+        if not skip:
+            save_file = os.path.join(self.project_folder, 'results', '%i.npz' % seg_id)
+            node_ids = list(mapping.keys())
+            save_labels = [node_labels[mapping[nid]] for nid in node_ids]
+
+            seed_ids = np.unique(seeds)[1:]
+            seeded_ids = []
+            seed_labels = []
+            for seed_id in seed_ids:
+                mask = seeds == seed_id
+                this_ids = np.unique(ws[mask])
+                if this_ids[0] == 0:
+                    this_ids = this_ids[1:]
+                seeded_ids.extend(this_ids)
+                seed_labels.extend(len(this_ids) * [seed_id])
+
+            np.savez_compressed(save_file, node_ids=node_ids, node_labels=save_labels,
+                                seeded_ids=seeded_ids, seed_labels=seed_labels)
+        self.processed_ids.append(int(seg_id))
+        with open(self.processed_ids_file, 'w') as f:
+            json.dump(self.processed_ids, f)
+
+    def __call__(self):
+        left_to_process = len(self.false_merge_ids) - len(self.processed_ids)
+        while left_to_process > 0:
+            seg_id, qitem = self.queue.get()
+            self.correct_segment(seg_id, qitem)
+            left_to_process = len(self.false_merge_ids) - len(self.processed_ids)
+
+    @staticmethod
+    def debug(debug_folder, n_threads=8):
+        with open_file(os.path.join(debug_folder, 'data.n5')) as f:
+            ds = f['raw']
+            ds.n_threads = n_threads
+            raw = ds[:]
+
+            ds = f['ws']
+            ds.n_threads = n_threads
+            ws = ds[:]
+
+            ds = f['seeds']
+            ds.n_threads = n_threads
+            seed_points = ds[:]
+
+        with open(os.path.join(debug_folder, 'mapping.json'), 'r') as f:
+            mapping = json.load(f)
+        mapping = {int(k): v for k, v in mapping.items()}
+
+        uv_ids = np.load(os.path.join(debug_folder, 'graph.npy'))
+        n_nodes = int(uv_ids.max()) + 1
+        graph = nifty.graph.undirectedGraph(n_nodes)
+        graph.insertEdges(uv_ids)
+        probs = np.load(os.path.join(debug_folder, 'probs.npy'))
+
+        node_labels = CorrectionTool.graph_watershed(graph, probs, ws, seed_points, mapping)
+
+        ws_ids = np.unique(ws)[1:]
+        label_dict = {wsid: node_labels[mapping[wsid]] for wsid in ws_ids}
+        label_dict[0] = 0
+        seg = nt.takeDict(label_dict, ws)
+
+        view(raw, ws, seg)
diff --git a/scripts/segmentation/correction/export_node_labels.py b/scripts/segmentation/correction/export_node_labels.py
new file mode 100644
index 0000000..7f3b4c3
--- /dev/null
+++ b/scripts/segmentation/correction/export_node_labels.py
@@ -0,0 +1,177 @@
+import os
+from math import ceil, floor
+
+import numpy as np
+import vigra
+from elf.io import open_file
+
+
+def export_node_labels(path, in_key, out_key, project_folder):
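+    # Each file <seg_id>.npz under <project_folder>/results stores the corrected labels
+    # for the fragments of one segment; they are merged back into the global node
+    # labeling with a running id offset, so new labels stay unique across segments.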
+    with open_file(path, 'r') as f:
+        ds = f[in_key]
+        node_labels = ds[:]
+
+    result_folder = os.path.join(project_folder, 'results')
+    res_files = os.listdir(result_folder)
+
+    print("Applying changes for", len(res_files), "resolved objects")
+
+    id_offset = int(node_labels.max()) + 1
+    for resf in res_files:
+        seg_id = int(os.path.splitext(resf)[0])
+        resf = os.path.join(result_folder, resf)
+        res = np.load(resf)
+
+        this_ids, this_labels = res['node_ids'], res['node_labels']
+        assert len(this_ids) == len(this_labels)
+        this_ids_exp = np.where(node_labels == seg_id)[0]
+        assert np.array_equal(np.sort(this_ids), this_ids_exp)
+
+        this_labels += id_offset
+        node_labels[this_ids] = this_labels
+        id_offset = int(this_labels.max()) + 1
+
+    with open_file(path) as f:
+        chunks = (min(int(1e6), len(node_labels)),)
+        ds = f.require_dataset(out_key, compression='gzip', dtype=node_labels.dtype,
+                               chunks=chunks, shape=node_labels.shape)
+        ds[:] = node_labels
+
+
+def to_paintera_format(in_path, in_key, out_path, out_key):
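+    # Writes a (2, N) fragment/segment lookup: row 0 holds the watershed (fragment)
+    # ids, row 1 the segment ids, offset by the number of fragments so the two id
+    # spaces do not collide; single-fragment segments are dropped as trivial.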
+    with open_file(in_path, 'r') as f:
+        node_labels = f[in_key][:]
+    node_labels = vigra.analysis.relabelConsecutive(node_labels, start_label=1, keep_zeros=True)[0]
+
+    n_ws = len(node_labels)
+    ws_ids = np.arange(n_ws, dtype='uint64')
+    assert len(ws_ids) == len(node_labels)
+
+    seg_ids, seg_counts = np.unique(node_labels, return_counts=True)
+    trivial_segments = seg_ids[seg_counts == 1]
+    trivial_mask = np.in1d(node_labels, trivial_segments)
+
+    ws_ids = ws_ids[~trivial_mask]
+    node_labels = node_labels[~trivial_mask]
+
+    node_labels[node_labels != 0] += n_ws
+
+    max_id = node_labels.max()
+    print("new max id:", max_id)
+
+    paintera_labels = np.concatenate([ws_ids[:, None], node_labels[:, None]], axis=1).T
+    print(paintera_labels.shape)
+
+    with open_file(out_path) as f:
+        chunks = (1, min(paintera_labels.shape[1], int(1e6)))
+        f.create_dataset(out_key, data=paintera_labels, chunks=chunks,
+                         compression='gzip')
+
+
+def zero_out_ids(node_label_in_path, node_label_in_key,
+                 node_label_out_path, node_label_out_key,
+                 zero_ids):
+    with open_file(node_label_in_path, 'r') as f:
+        ds = f[node_label_in_key]
+        node_labels = ds[:]
+        chunks = ds.chunks
+
+    zero_mask = np.isin(node_labels, zero_ids)
+    node_labels[zero_mask] = 0
+
+    with open_file(node_label_out_path) as f:
+        ds = f.require_dataset(node_label_out_key, shape=node_labels.shape, chunks=chunks,
+                               compression='gzip', dtype=node_labels.dtype)
+        ds[:] = node_labels
+
+
+# TODO refactor this properly
+def get_bounding_boxes(table_path, table_key, scale_factor):
+    with open_file(table_path, 'r') as f:
+        table = f[table_key][:]
+    bb_starts = table[:, 5:8]
+    bb_stops = table[:, 8:]
+    bb_starts /= scale_factor
+    bb_stops /= scale_factor
+    bounding_boxes = [tuple(slice(int(floor(sta)),
+                                  int(ceil(sto))) for sta, sto in zip(start, stop))
+                      for start, stop in zip(bb_starts, bb_stops)]
+    return bounding_boxes
+
+
+def check_exported(res_file, raw_path, raw_key, ws_path, ws_key,
+                   table_path, table_key, scale_factor):
+    from heimdall import view
+    import nifty.tools as nt
+    seg_id = int(os.path.splitext(os.path.split(res_file)[1])[0])
+    res = np.load(res_file)
+
+    bb = get_bounding_boxes(table_path, table_key, scale_factor)[seg_id]
+
+    node_ids, node_labels = res['node_ids'], res['node_labels']
+    assert len(node_ids) == len(node_labels)
+
+    with open_file(raw_path, 'r') as f:
+        ds = f[raw_key]
+        ds.n_threads = 8
+        raw = ds[bb]
+    with open_file(ws_path, 'r') as f:
+        ds = f[ws_key]
+        ds.n_threads = 8
+        ws = ds[bb]
+
+    seg_mask = np.isin(ws, node_ids)
+    ws[~seg_mask] = 0
+    label_dict = {wsid: lid for wsid, lid in zip(node_ids, node_labels)}
+    label_dict[0] = 0
+    seg = nt.takeDict(label_dict, ws)
+
+    view(raw, seg)
+
+
+def check_exported_paintera(paintera_path, assignment_key,
+                            node_label_path, node_label_key,
+                            table_path, table_key, scale_factor,
+                            raw_path, raw_key, seg_path, seg_key,
+                            check_ids):
+    from heimdall import view
+    import nifty.tools as nt
+
+    with open_file(paintera_path, 'r') as f:
+        ds = f[assignment_key]
+        new_assignments = ds[:].T
+
+    with open_file(node_label_path, 'r') as f:
+        ds = f[node_label_key]
+        node_labels = ds[:]
+
+    bounding_boxes = get_bounding_boxes(table_path, table_key, scale_factor)
+    with open_file(seg_path, 'r') as fseg, open_file(raw_path, 'r') as fraw:
+        ds_seg = fseg[seg_key]
+        ds_seg.n_threads = 8
+        ds_raw = fraw[raw_key]
+        ds_raw.n_threads = 8
+
+        for seg_id in check_ids:
+            bb = bounding_boxes[seg_id]
+            raw = ds_raw[bb]
+            ws = ds_seg[bb]
+
+            ws_ids = np.where(node_labels == seg_id)[0]
+            seg_mask = np.isin(ws, ws_ids)
+            ws[~seg_mask] = 0
+
+            new_label_mask = np.isin(new_assignments[:, 0], ws_ids)
+            new_label_dict = dict(zip(new_assignments[:, 0][new_label_mask],
+                                      new_assignments[:, 1][new_label_mask]))
+            new_label_dict[0] = 0
+
+            # I am not sure why this happens
+            un_ws = np.unique(ws)
+            un_labels = list(new_label_dict.keys())
+            missing = np.setdiff1d(un_ws, un_labels)
+            print("Number of missing:", len(missing))
+            new_label_dict.update({miss: 0 for miss in missing})
+
+            seg_new = nt.takeDict(new_label_dict, ws)
+            view(raw, seg_mask.astype('uint32'), seg_new)
diff --git a/scripts/segmentation/correction/heuristics.py b/scripts/segmentation/correction/heuristics.py
new file mode 100644
index 0000000..8a16ecf
--- /dev/null
+++ b/scripts/segmentation/correction/heuristics.py
@@ -0,0 +1,173 @@
+import json
+from concurrent import futures
+from math import ceil, floor
+
+import numpy as np
+import nifty.distributed as ndist
+import vigra
+
+from scipy.ndimage.morphology import binary_closing, binary_opening, binary_erosion
+from skimage.morphology import convex_hull_image
+from elf.io import open_file
+
+
+#
+# heuristics to find morphology outliers
+#
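+# These scores compare an object's mask against a morphologically transformed version
+# of itself (closing, opening, convex hull) or count connected components per slice;
+# unusually large values flag implausible shapes. A hypothetical call via
+# compute_ratios (defined below):
+#   ids, ratios = compute_ratios('seg.h5', 'seg-key', 'morphology.h5', 'table',
+#                                scale_factor=16, n_threads=8,
+#                                compute_ratio=lambda obj: closing_ratio(obj, 4),
+#                                sort=True)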
+
+
+def closing_ratio(seg, n_iters):
+    seg_closed = binary_closing(seg, iterations=n_iters)
+    m1 = float(seg.sum())
+    m2 = float(seg_closed.sum())
+    return m2 / m1
+
+
+def opening_ratio(seg, n_iters):
+    seg_opened = binary_opening(seg, iterations=n_iters)
+    m1 = float(seg.sum())
+    m2 = float(seg_opened.sum())
+    return m1 / m2
+
+
+def convex_hull_ratio(seg):
+    seg_conv = convex_hull_image(seg)
+    m1 = float(seg.sum())
+    m2 = float(seg_conv.sum())
+    return m2 / m1
+
+
+def components_per_slice(seg, n_iters=0):
+    n_components = 0
+    n_slices = seg.shape[0]
+    if n_slices == 0:
+        return 0.
+
+    for z in range(n_slices):
+        if n_iters > 0:
+            segz = binary_erosion(seg[z], iterations=n_iters).astype('uint32')
+        else:
+            segz = seg[z].astype('uint32')
+        n_components += len(np.unique(vigra.analysis.labelImageWithBackground(segz))[1:])
+    n_components /= n_slices
+    return n_components
+
+
+def read_bb_from_table(table_path, table_key, scale_factor):
+    with open_file(table_path, 'r') as f:
+        table = f[table_key][:]
+    bb_starts = table[:, 5:8] / scale_factor
+    bb_stops = table[:, 8:] / scale_factor
+    bbs = [tuple(slice(int(floor(sta)),
+                       int(ceil(sto))) for sta, sto in zip(starts, stops))
+           for starts, stops in zip(bb_starts, bb_stops)]
+    return bbs
+
+
+def compute_ratios(seg_path, seg_key, table_path, table_key,
+                   scale_factor, n_threads,
+                   compute_ratio, seg_ids=None, sort=False):
+    with open_file(seg_path, 'r') as f:
+        ds = f[seg_key]
+        bounding_boxes = read_bb_from_table(table_path, table_key, scale_factor)
+
+        if seg_ids is None:
+            seg_ids = np.arange(len(bounding_boxes))
+        n_seg_ids = len(seg_ids)
+
+        def _compute_ratio(seg_id):
+            print("%i / %i" % (seg_id, n_seg_ids))
+            bb = bounding_boxes[seg_id]
+            seg = ds[bb]
+            seg = seg == seg_id
+            ratio = compute_ratio(seg)
+            return ratio
+
+        with futures.ThreadPoolExecutor(n_threads) as tp:
+            tasks = [tp.submit(_compute_ratio, seg_id) for seg_id in seg_ids]
+            ratios = np.array([t.result() for t in tasks])
+
+    if sort:
+        sorted_ids = np.argsort(ratios)[::-1]
+        seg_ids = seg_ids[sorted_ids]
+        ratios = ratios[sorted_ids]
+
+    return seg_ids, ratios
+
+
+#
+# heuristics to find likely false merges
+#
+
+def get_ignore_ids(label_path, label_key,
+                   ignore_names=['yolk', 'cuticle', 'neuropil']):
+    with open_file(label_path, 'r') as f:
+        ds = f[label_key]
+        semantics = ds.attrs['semantics']
+        labels = ds[:]
+    ignore_label_ids = []
+    for name, ids in semantics.items():
+        if name in ignore_names:
+            ignore_label_ids.extend(ids)
+    ignore_ids = np.isin(labels, ignore_label_ids)
+    ignore_ids = np.where(ignore_ids)[0]
+    return ignore_ids
+
+
+def weight_quantile_heuristic(seg_id, graph, node_labels, sizes, max_size, weights,
+                              quantile=90):
+    size_ratio = float(sizes[seg_id]) / max_size
+    node_ids = np.where(node_labels == seg_id)[0]
+    edges = graph.extractSubgraphFromNodes(node_ids, allowInvalidNodes=True)[0]
+    this_weights = weights[edges]
+    try:
+        score = np.percentile(this_weights, quantile) * size_ratio
+    except IndexError:
+        print("Something went wrong", seg_id, this_weights.shape)
+        score = 0.
+    return score
+
+
+def rank_false_merges(problem_path, graph_key, feat_key,
+                      morpho_key, node_label_path, node_label_key,
+                      ignore_ids, out_path_ids, out_path_scores,
+                      n_threads, n_candidates, heuristic=weight_quantile_heuristic):
+    g = ndist.Graph(problem_path, graph_key, n_threads)
+    with open_file(problem_path, 'r') as f:
+        ds = f[feat_key]
+        ds.n_threads = n_threads
+        probs = ds[:, 0]
+
+        ds = f[morpho_key]
+        ds.n_threads = n_threads
+        sizes = ds[:, 1]
+
+    with open_file(node_label_path, 'r') as f:
+        ds = f[node_label_key]
+        ds.n_threads = n_threads
+        node_labels = ds[:]
+
+    seg_ids = np.arange(len(sizes), dtype='uint64')
+    seg_ids = seg_ids[np.argsort(sizes)[::-1]][:n_candidates]
+    seg_ids = seg_ids[~np.isin(seg_ids, ignore_ids.tolist() + [0])]
+    max_size = sizes[seg_ids].max()
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        tasks = [tp.submit(heuristic, seg_id, g,
+                           node_labels, sizes, max_size, probs) for seg_id in seg_ids]
+        fm_scores = np.array([t.result() for t in tasks])
+
+    # print("Id:", seg_ids[0])
+    # sc = weight_quantile_heuristic(seg_ids[0], g,
+    #                                node_labels, sizes, max_size, probs)
+    # print("Score:", sc)
+    # return
+
+    # sort ids by score (decreasing)
+    sorter = np.argsort(fm_scores)[::-1]
+    seg_ids = seg_ids[sorter]
+    fm_scores = fm_scores[sorter]
+
+    with open(out_path_scores, 'w') as f:
+        json.dump(fm_scores.tolist(), f)
+    with open(out_path_ids, 'w') as f:
+        json.dump(seg_ids.tolist(), f)
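
The ranking above boils down to taking a high quantile of the edge weights inside each segment and scaling it by the segment's relative size. A minimal numpy-only sketch of that scoring idea on made-up toy inputs (no nifty graph involved; all names and numbers here are illustrative, not part of the patch):

    import numpy as np

    def toy_false_merge_score(edge_weights, size, max_size, quantile=90):
        # a high weight quantile inside a large segment gives a high score
        return np.percentile(edge_weights, quantile) * (size / max_size)

    sizes = np.array([1000., 400.])
    edge_weights_per_segment = [np.array([0.9, 0.8, 0.95]),   # segment 0
                                np.array([0.1, 0.2, 0.05])]   # segment 1
    scores = [toy_false_merge_score(w, s, sizes.max())
              for w, s in zip(edge_weights_per_segment, sizes)]
    print(scores)  # segment 0 scores highest and would be ranked first
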
diff --git a/scripts/segmentation/correction/mark_false_merges.py b/scripts/segmentation/correction/mark_false_merges.py
deleted file mode 100644
index e05bfc7..0000000
--- a/scripts/segmentation/correction/mark_false_merges.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import os
-import json
-import napari
-from elf.io import open_file
-
-
-# TODO
-# mark false merges in segmentation:
-# open napari viewer with raw data and segmentation (in appropriate res)
-# set annotations in 'has_false_merge' and 'is_correct' layer
-# clear segmentations that have an annotation (upon key press)
-# store annotations in the project folder
-class MarkFalseMerges:
-    def __init__(self, project_folder,
-                 raw_path=None, raw_key=None,
-                 seg_path=None, seg_key=None):
-        pass
-
-    def __call__(self):
-        pass
diff --git a/scripts/segmentation/correction/preprocess.py b/scripts/segmentation/correction/preprocess.py
new file mode 100644
index 0000000..0522800
--- /dev/null
+++ b/scripts/segmentation/correction/preprocess.py
@@ -0,0 +1,117 @@
+import os
+import json
+import luigi
+from elf.io import open_file
+
+from paintera_tools import serialize_from_commit
+from paintera_tools.util import compute_graph_and_weights
+from cluster_tools.node_labels import NodeLabelWorkflow
+from cluster_tools.morphology import MorphologyWorkflow
+from cluster_tools.downscaling import DownscalingWorkflow
+
+
+def graph_and_costs(path, ws_key, aff_key, out_path):
+    tmp_folder = './tmp_preprocess'
+    compute_graph_and_weights(path, aff_key,
+                              path, ws_key, out_path,
+                              tmp_folder, target='slurm', max_jobs=250,
+                              offsets=[[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
+                              with_costs=True)
+
+
+def accumulate_node_labels(ws_path, ws_key, seg_path, seg_key,
+                           out_path, out_key, prefix):
+    task = NodeLabelWorkflow
+
+    tmp_folder = './tmp_preprocess'
+    config_dir = os.path.join(tmp_folder, 'configs')
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
+             target='slurm', max_jobs=250,
+             ws_path=ws_path, ws_key=ws_key,
+             input_path=seg_path, input_key=seg_key,
+             output_path=out_path, output_key=out_key,
+             prefix=prefix)
+    ret = luigi.build([t], local_scheduler=True)
+    assert ret
+
+
+def compute_bounding_boxes(path, key):
+    task = MorphologyWorkflow
+    tmp_folder = './tmp_preprocess'
+    config_dir = os.path.join(tmp_folder, 'configs')
+
+    out_key = 'morphology'
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
+             target='slurm', max_jobs=250,
+             input_path=path, input_key=key,
+             output_path=path, output_key=out_key)
+    ret = luigi.build([t], local_scheduler=True)
+    assert ret
+
+
+def downscale_segmentation(path, key):
+    task = DownscalingWorkflow
+
+    tmp_folder = './tmp_preprocess'
+    config_dir = os.path.join(tmp_folder, 'configs')
+
+    configs = task.get_config()
+    conf = configs['downscaling']
+    conf.update({'library_kwargs': {'order': 0}})
+    with open(os.path.join(config_dir, 'downscaling.config'), 'w') as f:
+        json.dump(conf, f)
+
+    in_key = os.path.join(key, 's0')
+    n_scales = 5
+    scales = n_scales * [[2, 2, 2]]
+    halos = n_scales * [[0, 0, 0]]
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
+             # target='slurm', max_jobs=250,
+             target='local', max_jobs=64,
+             input_path=path, input_key=in_key,
+             scale_factors=scales, halos=halos,
+             output_path=path, output_key_prefix=key)
+    ret = luigi.build([t], local_scheduler=True)
+    assert ret
+
+
+def copy_tissue_labels(in_path, out_path, out_key):
+    with open_file(in_path, 'r') as f:
+        names = f['semantic_names'][:]
+        mapping = f['semantic_mapping'][:]
+
+    semantics = {name: ids.tolist() for name, ids in zip(names, mapping)}
+    with open_file(out_path) as f:
+        ds = f[out_key]
+        ds.attrs['semantics'] = semantics
+
+
+# preprocess:
+# - export current paintera segmentation
+# - build graph and compute weights for current superpixels
+# - get current node labeling
+# - compute bounding boxes for current segments
+# - downscale the segmentation
+def preprocess(path, key, aff_key,
+               tissue_path, tissue_key,
+               out_path, out_key):
+    tmp_folder = './tmp_preprocess'
+    out_key0 = os.path.join(out_key, 's0')
+
+    serialize_from_commit(path, key, out_path, out_key0, tmp_folder,
+                          max_jobs=250, target='slurm', relabel_output=True)
+
+    ws_key = os.path.join(key, 'data', 's0')
+    graph_and_costs(path, ws_key, aff_key, out_path)
+
+    accumulate_node_labels(path, ws_key, out_path, out_key0,
+                           out_path, 'node_labels', prefix='node_labels')
+
+    accumulate_node_labels(out_path, out_key0, tissue_path, tissue_key,
+                           out_path, 'tissue_labels', prefix='tissue')
+    copy_tissue_labels(tissue_path, out_path, 'tissue_labels')
+
+    compute_bounding_boxes(out_path, out_key0)
+    downscale_segmentation(out_path, out_key)
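
For orientation, these are the datasets the downstream scripts in this patch expect to find in out_path once preprocess has run. This layout is an editorial summary inferred from the calls above and from correct_cell_segmentation.py, not a guaranteed or exhaustive listing:

    # expected layout of out_path (e.g. <project_folder>/data.n5) after preprocess
    expected_datasets = [
        'segmentation/s0',    # exported paintera segmentation, full resolution
        'segmentation/s1',    # ... downscaled copies, up to s5
        'node_labels',        # superpixel -> segment assignment
        'tissue_labels',      # superpixel -> tissue label, with 'semantics' attribute
        'morphology',         # per-segment morphology table incl. bounding boxes
        's0/graph',           # region adjacency graph of the superpixels
        'features',           # edge weights / costs for the graph
    ]
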
diff --git a/scripts/segmentation/validation/eval_nuclei.py b/scripts/segmentation/validation/eval_nuclei.py
index 8b56dc4..86b985f 100644
--- a/scripts/segmentation/validation/eval_nuclei.py
+++ b/scripts/segmentation/validation/eval_nuclei.py
@@ -1,6 +1,72 @@
-from .evaluate_annotations import evaluate_annotations
+from math import ceil, floor
+import vigra
+from elf.io import open_file, is_dataset
+from .evaluate_annotations import evaluate_annotations, merge_evaluations
 
 
-# TODO
-def eval_nuclei():
-    pass
+def get_bounding_box(ds, scale_factor):
+    attrs = ds.attrs
+    start, stop = attrs['starts'], attrs['stops']
+    bb = tuple(slice(int(floor(sta / scale_factor)),
+                     int(ceil(sto / scale_factor))) for sta, sto in zip(start, stop))
+    return bb
+
+
+def to_scores(eval_res):
+    n = float(eval_res['n_annotations'] - eval_res['n_unmatched'])
+    fp = eval_res['n_splits']
+    fn = eval_res['n_merged_annotations']
+    return fp / n, fn / n, n
+
+
+# we may want to do a max projection over some z context here?
+def get_nucleus_segmentation(ds_seg, bb):
+    seg = ds_seg[bb].squeeze().astype('uint32')
+    return seg
+
+
+# need to downsample the annotations and bounding box to fit the
+# nucleus segmentation
+def eval_slice(ds_seg, ds_ann, min_radius, return_masks=False):
+    ds_seg.n_threads = 8
+    ds_ann.n_threads = 8
+
+    bb = get_bounding_box(ds_ann, scale_factor=4.)
+    annotations = ds_ann[:]
+    seg = get_nucleus_segmentation(ds_seg, bb)
+    annotations = vigra.sampling.resize(annotations.astype('float32'),
+                                        shape=seg.shape, order=0).astype('uint32')
+
+    fg_annotations = (annotations == 1).astype('uint32')
+    bg_annotations = None
+
+    return evaluate_annotations(seg, fg_annotations, bg_annotations,
+                                min_radius=min_radius, return_masks=return_masks)
+
+
+def eval_nuclei(seg_path, seg_key,
+                annotation_path, annotation_key=None,
+                min_radius=6):
+    """ Evaluate the nucleus segmentation by computing
+    the percentage of false positive and false negative nucleus annotations
+    in manually annotated validation slices.
+    """
+    eval_res = {}
+    with open_file(seg_path, 'r') as f_seg, open_file(annotation_path, 'r') as f_ann:
+        ds_seg = f_seg[seg_key]
+        g = f_ann if annotation_key is None else f_ann[annotation_key]
+
+        def visit_annotation(name, node):
+            nonlocal eval_res
+            if is_dataset(node):
+                print("Evaluating:", name)
+                res = eval_slice(ds_seg, node, min_radius)
+                eval_res = merge_evaluations(res, eval_res)
+                # for debugging
+                # print("current eval:", eval_res)
+            else:
+                print("Group:", name)
+
+        g.visititems(visit_annotation)
+
+    return to_scores(eval_res)
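
For concreteness, a small worked example of the normalization performed by to_scores above, with made-up counts (an editorial sketch, not output from the actual evaluation):

    # toy numbers, only to illustrate the normalization in to_scores
    eval_res = {'n_annotations': 50, 'n_unmatched': 2,
                'n_splits': 3, 'n_merged_annotations': 6}
    n = float(eval_res['n_annotations'] - eval_res['n_unmatched'])   # 48 matched annotations
    fp_rate = eval_res['n_splits'] / n                               # 0.0625
    fn_rate = eval_res['n_merged_annotations'] / n                   # 0.125
    print(fp_rate, fn_rate, n)
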
diff --git a/scripts/segmentation/validation/evaluate_annotations.py b/scripts/segmentation/validation/evaluate_annotations.py
index 4fda835..ab4ffd0 100644
--- a/scripts/segmentation/validation/evaluate_annotations.py
+++ b/scripts/segmentation/validation/evaluate_annotations.py
@@ -22,7 +22,7 @@ def get_radii(seg):
     return radii
 
 
-def evaluate_annotations(seg, fg_annotations, bg_annotations,
+def evaluate_annotations(seg, fg_annotations, bg_annotations=None,
                          ignore_seg_ids=None, min_radius=16,
                          return_masks=False, return_ids=False):
     """ Evaluate segmentation based on evaluations.
@@ -54,7 +54,7 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations,
         # check if this is an ignore id and skip
         if ignore_seg_ids is not None and seg_id in ignore_seg_ids:
             continue
-        has_bg_label = bg_annotations[mask].sum() > 0
+        has_bg_label = False if bg_annotations is None else bg_annotations[mask].sum() > 0
 
         # find the overlapping label ids
         this_labels = np.unique(labels[mask])
diff --git a/scripts/segmentation/validation/refine_annotations.py b/scripts/segmentation/validation/refine_annotations.py
index f58df90..c8fbf2f 100644
--- a/scripts/segmentation/validation/refine_annotations.py
+++ b/scripts/segmentation/validation/refine_annotations.py
@@ -33,8 +33,8 @@ def compute_masks(seg, labels, ignore_seg_ids):
 def refine(seg_path, seg_key, ignore_seg_ids,
            orientation, slice_id,
            project_folder,
-           annotation_path='/g/arendt/...',
-           raw_path='/g/arendt/...',
+           annotation_path='/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation/validation_annotations.h5',
+           raw_path='/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/sbem-6dpf-1-whole-raw.h5',
            raw_key='t00000/s00/1/cells'):
 
     label_path = os.path.join(project_folder, 'labels.npy')
@@ -47,6 +47,8 @@ def refine(seg_path, seg_key, ignore_seg_ids,
         labels = np.load(label_path) if os.path.exists(label_path) else None
         fm = np.load(fm_path) if os.path.exists(fm_path) else None
         fs = np.load(fs_path) if os.path.exists(fs_path) else None
+    else:
+        labels, fm, fs = None, None, None
 
     with open_file(annotation_path, 'r') as fval:
         ds = fval[orientation][str(slice_id)]
@@ -55,8 +57,8 @@ def refine(seg_path, seg_key, ignore_seg_ids,
         if labels is None:
             labels = ds[:]
 
-    starts = [b.start for b in bb]
-    stops = [b.stop for b in bb]
+    starts = [int(b.start) for b in bb]
+    stops = [int(b.stop) for b in bb]
 
     with open_file(seg_path, 'r') as f:
         ds = f[seg_key]
@@ -68,7 +70,7 @@ def refine(seg_path, seg_key, ignore_seg_ids,
         ds.n_threads = 8
         raw = ds[bb].squeeze()
 
-    assert labels.shape == seg.shape
+    assert labels.shape == seg.shape == raw.shape
     if fm is None:
         assert fs is None
         fm, fs = compute_masks(seg, labels, ignore_seg_ids)
diff --git a/segmentation/correction/correct_cell_segmentation.py b/segmentation/correction/correct_cell_segmentation.py
new file mode 100644
index 0000000..d3b88ac
--- /dev/null
+++ b/segmentation/correction/correct_cell_segmentation.py
@@ -0,0 +1,175 @@
+import os
+import json
+from scripts.segmentation.correction import (preprocess,
+                                             AnnotationTool,
+                                             CorrectionTool,
+                                             export_node_labels,
+                                             to_paintera_format,
+                                             rank_false_merges,
+                                             get_ignore_ids)
+
+
+def run_preprocessing(project_folder):
+    os.makedirs(project_folder, exist_ok=True)
+
+    path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    key = 'volumes/paintera/proofread_cells'
+    aff_key = 'volumes/curated_affinities/s1'
+
+    out_path = os.path.join(project_folder, 'data.n5')
+    out_key = 'segmentation'
+
+    tissue_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data',
+                               'rawdata/sbem-6dpf-1-whole-segmented-tissue-labels.h5')
+    tissue_key = 't00000/s00/0/cells'
+
+    preprocess(path, key, aff_key, tissue_path, tissue_key,
+               out_path, out_key)
+
+
+def run_heuristics(project_folder):
+    path = os.path.join(project_folder, 'data.n5')
+    ignore_ids = get_ignore_ids(path, 'tissue_labels')
+
+    out_path_ids = os.path.join(project_folder, 'fm_candidate_ids.json')
+    out_path_scores = os.path.join(project_folder, 'fm_candidate_scores.json')
+    n_threads = 32
+    n_candidates = 10000
+
+    rank_false_merges(path, 's0/graph', 'features',
+                      'morphology', path, 'node_labels',
+                      ignore_ids, out_path_ids, out_path_scores,
+                      n_threads, n_candidates)
+    return out_path_ids
+
+
+def run_annotations(id_path, project_folder, scale=2,
+                    with_node_labels=True):
+    p1 = os.path.join(project_folder, 'data.n5')
+    table_key = 'morphology'
+    scale_factor = 2 ** scale
+
+    p2 = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    rk = 'volumes/raw/s%i' % (scale + 1,)
+
+    if with_node_labels:
+        node_label_path = p1
+        node_label_key = 'node_labels'
+        wsp = p2
+        wsk = 'volumes/paintera/proofread_cells/data/s%i' % scale
+    else:
+        node_label_path, node_label_key = None, None
+        wsp = p1
+        wsk = 'segmentation/s%i' % scale
+
+    annotator = AnnotationTool(project_folder, id_path,
+                               p1, table_key, scale_factor,
+                               p2, rk, wsp, wsk,
+                               node_label_path, node_label_key)
+    annotator()
+
+
+def filter_fm_ids_by_annotations(project_folder, annotation_path):
+    annotation = 'revisit'
+    with open(annotation_path) as f:
+        annotations = json.load(f)
+    annotations = {int(k): v for k, v in annotations.items()}
+    ids = [k for k, v in annotations.items() if v == annotation]
+    print("Found", len(ids), "ids with annotation", annotation)
+    out_path = os.path.join(project_folder, 'fm_ids_filtered.json')
+    with open(out_path, 'w') as f:
+        json.dump(ids, f)
+    return out_path
+
+
+def run_correction(project_folder, fm_id_path, scale=2):
+    p1 = os.path.join(project_folder, 'data.n5')
+    table_key = 'morphology'
+    segk = 'segmentation/s%i' % scale
+    scale_factor = 2 ** scale
+
+    gk = 's0/graph'
+    fk = 'features'
+
+    p2 = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    rk = 'volumes/raw/s%i' % (scale + 1,)
+    wsk = 'volumes/paintera/proofread_cells/data/s%i' % scale
+
+    # table path
+    correcter = CorrectionTool(project_folder, fm_id_path,
+                               p1, table_key, scale_factor,
+                               p2, rk, p1, segk, p2, wsk,
+                               p1, 'node_labels',
+                               p1, gk, fk)
+    correcter()
+
+
+def export_correction(project_folder, correct_merges=True, zero_out=True):
+    from scripts.segmentation.correction.export_node_labels import zero_out_ids
+
+    p = os.path.join(project_folder, 'data.n5')
+    in_key = 'node_labels'
+    out_key = 'node_labels_corrected'
+
+    if correct_merges:
+        print("Correcting merged ids")
+        project_folder = './proj_correct2'
+        export_node_labels(p, in_key, out_key, project_folder)
+        next_in_key = out_key
+    else:
+        next_in_key = in_key
+
+    def get_zero_ids(annotation_path):
+        annotation = 'merge'
+        with open(annotation_path) as f:
+            annotations = json.load(f)
+        annotations = {int(k): v for k, v in annotations.items()}
+        zero_ids = [k for k, v in annotations.items() if v == annotation]
+        return zero_ids
+
+    if zero_out:
+        # read merge annotations from the morphology annotator
+        annotation_path = './proj_annotate_morphology/annotations.json'
+        zero_ids = get_zero_ids(annotation_path)
+        # read additional merge annotations
+        annotation_path = './proj_correct2/annotations.json'
+        zero_ids += get_zero_ids(annotation_path)
+        print("Zeroing out", len(zero_ids), "ids")
+
+        zero_out_ids(p, next_in_key, p, out_key, zero_ids)
+
+    paintera_path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    paintera_key = 'volumes/paintera/proofread_cells_multiset/fragment-segment-assignment'
+    print("Exporting to paintera format")
+    to_paintera_format(p, out_key, paintera_path, paintera_key)
+
+
+def correction_workflow_from_ranked_false_merges():
+    # the project folder where we store intermediate results
+    # for the false merge correction
+    project_folder = './project_correct_false_merges'
+
+    # compute the segmentation, the region adjacency graph and the graph weights
+    run_preprocessing(project_folder)
+
+    # rank false merge candidates with the edge weight quantile heuristic
+    # (see rank_false_merges in scripts.segmentation.correction.heuristics)
+    fm_id_path = run_heuristics(project_folder)
+
+    # run annotations to quickly filter for the false merges that need to be
+    # corrected. this is not strictly necessary (we can use run_correction directly)
+    # but in my experience, it is faster to first just filter for false merges
+    # and then to correct them.
+    run_annotations(fm_id_path, project_folder)
+    annotation_path = os.path.join(project_folder, 'annotations.json')
+    fm_id_path = filter_fm_ids_by_annotations(project_folder, annotation_path)
+
+    # run the false merge correction tool
+    run_correction(project_folder, fm_id_path)
+
+    # export the corrected node labels to paintera
+    export_correction(project_folder)
+
+
+if __name__ == '__main__':
+    correction_workflow_from_ranked_false_merges()
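
Both filter_fm_ids_by_annotations and get_zero_ids above read an annotations.json that maps segment ids (JSON keys, hence strings) to annotation labels. A minimal hypothetical example of such a file's contents and how it is parsed; the ids are invented, the 'revisit' and 'merge' labels match what the functions above filter for:

    import json

    # hypothetical contents of <project_folder>/annotations.json
    annotations_json = '{"1532": "revisit", "2077": "merge"}'
    annotations = {int(k): v for k, v in json.loads(annotations_json).items()}

    revisit_ids = [k for k, v in annotations.items() if v == 'revisit']  # kept for correction
    merge_ids = [k for k, v in annotations.items() if v == 'merge']      # zeroed out on export
    print(revisit_ids, merge_ids)  # [1532] [2077]
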
diff --git a/segmentation/correction/deprecated.py b/segmentation/correction/deprecated.py
new file mode 100644
index 0000000..a03ded7
--- /dev/null
+++ b/segmentation/correction/deprecated.py
@@ -0,0 +1,47 @@
+#
+# just keeping this here for reference
+#
+
+import os
+import json
+
+
+def get_skipped_ids(out_path):
+    processed_ids = './proj_correct_fms/processed_ids.json'
+    res_dir = './proj_correct_fms/results'
+    with open(processed_ids) as f:
+        processed_ids = json.load(f)
+
+    saved_ids = os.listdir(res_dir)
+    saved_ids = [int(os.path.splitext(sid)[0]) for sid in saved_ids]
+
+    skipped_ids = list(set(processed_ids) - set(saved_ids))
+    print("N-skipped:")
+    print(len(skipped_ids))
+
+    with open(out_path, 'w') as f:
+        json.dump(skipped_ids, f)
+
+
+def check_export():
+    from scripts.segmentation.correction.export_node_labels import check_exported_paintera
+
+    path1 = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    assignment_key = 'volumes/paintera/proofread_cells/fragment-segment-assignment'
+    raw_key = 'volumes/raw/s3'
+    seg_key = 'volumes/paintera/proofread_cells/data/s2'
+
+    path2 = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/analysis/correction/data.n5'
+    node_label_key = 'node_labels'
+    table_key = 'morphology'
+    scale_factor = 4
+
+    res_folder = './proj_correct_fms/results'
+    ids = os.listdir(res_folder)[:10]
+    ids = [int(os.path.splitext(idd)[0]) for idd in ids]
+
+    check_exported_paintera(path1, assignment_key,
+                            path2, node_label_key,
+                            path2, table_key, scale_factor,
+                            path1, raw_key, path1, seg_key,
+                            ids)
diff --git a/segmentation/correction/fix_assignments.py b/segmentation/correction/fix_assignments.py
new file mode 100644
index 0000000..e8ad6ee
--- /dev/null
+++ b/segmentation/correction/fix_assignments.py
@@ -0,0 +1,69 @@
+import os
+import json
+
+
+def get_assignments(version):
+    from scripts.segmentation.correction.assignment_diffs import node_labels
+    assert version in ('0.5.5', '0.6.1', 'local')
+    ws_path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    ws_key = 'volumes/paintera/proofread_cells_multiset/data/s0'
+
+    if version == 'local':
+        in_path = './data.n5'
+        in_key = 'segmentation/s0'
+        prefix = version
+    else:
+        in_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data',
+                               '%s/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5' % version)
+        in_key = 't00000/s00/0/cells'
+
+        prefix = ''.join(version.split('.'))
+
+    out_path = './data.n5'
+    out_key = 'assignments/%s' % prefix
+
+    tmp_folder = './tmp_%s' % prefix
+
+    node_labels(ws_path, ws_key, in_path, in_key,
+                out_path, out_key, prefix, tmp_folder)
+
+
+def get_split_assignments():
+    import z5py
+    from scripts.segmentation.correction.assignment_diffs import assignment_diff_splits
+    with z5py.File('./data.n5') as f:
+        ref = f['assignments/055'][:]
+        new = f['assignments/corrected'][:]
+
+    splits = assignment_diff_splits(ref, new)
+
+    print(len(splits))
+
+    # with open('./split_assignments_local.json', 'w') as f:
+    #     json.dump(splits, f)
+
+
+def blub():
+    with open('./split_assignments.json') as f:
+        split_assignments = json.load(f)
+    ids = list(split_assignments.keys())
+    # vals = list(split_assignments.values())
+
+    correction_path = './proj_correct2/results'
+    correct_ids = os.listdir(correction_path)
+    correct_ids = [int(os.path.splitext(cid)[0]) for cid in correct_ids]
+
+    print("Number of assignments which are split:", len(ids))
+    print("Number of assignments which were supposed to be split:", len(correct_ids))
+
+    diff = set(ids) - set(correct_ids)
+    print("Number after diff:", len(diff))
+
+
+if __name__ == '__main__':
+    # get_assignments('0.5.5')
+    # get_assignments('0.6.1')
+    # get_assignments('local')
+
+    get_split_assignments()
+    # blub()
diff --git a/segmentation/correction/morphology_outlier.py b/segmentation/correction/morphology_outlier.py
new file mode 100644
index 0000000..5a2cdfd
--- /dev/null
+++ b/segmentation/correction/morphology_outlier.py
@@ -0,0 +1,53 @@
+import os
+import json
+import numpy as np
+import z5py
+from scripts.segmentation.correction import preprocess
+from scripts.segmentation.correction.heuristics import (compute_ratios,
+                                                        components_per_slice,
+                                                        get_ignore_ids)
+
+
+def run_preprocessing():
+    path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    key = 'volumes/paintera/proofread_cells'
+    aff_key = 'volumes/curated_affinities/s1'
+
+    out_path = './data.n5'
+    out_key = 'segmentation'
+
+    tissue_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data',
+                               'rawdata/sbem-6dpf-1-whole-segmented-tissue-labels.h5')
+    tissue_key = 't00000/s00/0/cells'
+
+    preprocess(path, key, aff_key, tissue_path, tissue_key,
+               out_path, out_key)
+
+
+def morphology_outlier():
+    scale = 2
+    p1 = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/analysis/correction/data.n5'
+    seg_key = 'segmentation/s%i' % scale
+    table_key = 'morphology'
+    scale_factor = 2 ** scale
+
+    with z5py.File(p1, 'r') as f:
+        ds = f['segmentation/s0']
+        n_ids = ds.attrs['maxId'] + 1
+
+    seg_ids = np.arange(n_ids)
+    ignore_ids = get_ignore_ids(p1, 'tissue_labels')
+    ignore_id_mask = ~np.isin(seg_ids, ignore_ids)
+    seg_ids = seg_ids[ignore_id_mask]
+
+    seg_ids, ratios_close = compute_ratios(p1, seg_key, p1, table_key, scale_factor,
+                                           64, components_per_slice, seg_ids, True)
+    with open('./ratios_components.json', 'w') as f:
+        json.dump(ratios_close.tolist(), f)
+    with open('./ids_morphology.json', 'w') as f:
+        json.dump(seg_ids.tolist(), f)
+
+
+if __name__ == '__main__':
+    # run_preprocessing()
+    morphology_outlier()
diff --git a/segmentation/validation/check_validation.py b/segmentation/validation/check_validation.py
new file mode 100644
index 0000000..d85f41f
--- /dev/null
+++ b/segmentation/validation/check_validation.py
@@ -0,0 +1,106 @@
+import os
+import h5py
+from heimdall import view, to_source
+from scripts.segmentation.validation import eval_cells, eval_nuclei, get_ignore_seg_ids
+
+# ROOT_FOLDER = '../../data'
+ROOT_FOLDER = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data'
+
+
+def check_cell_evaluation():
+    from scripts.segmentation.validation.eval_cells import (eval_slice,
+                                                            get_bounding_box)
+
+    praw = os.path.join(ROOT_FOLDER, 'rawdata/sbem-6dpf-1-whole-raw.h5')
+    pseg = os.path.join(ROOT_FOLDER, '0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv'
+    ignore_seg_ids = get_ignore_seg_ids(table_path)
+
+    with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann:
+        ds_seg = fseg['t00000/s00/0/cells']
+        ds_ann = fann['xy/1000']
+
+        print("Run evaluation ...")
+        res, masks = eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius=16,
+                                return_masks=True)
+        fm, fs = masks['merges'], masks['splits']
+        print()
+        print("Eval result")
+        print(res)
+        print()
+
+        print("Load raw data ...")
+        bb = get_bounding_box(ds_ann)
+        with h5py.File(praw, 'r') as f:
+            raw = f['t00000/s00/1/cells'][bb].squeeze()
+
+        print("Load seg data ...")
+        seg = ds_seg[bb].squeeze().astype('uint32')
+
+        view(to_source(raw, name='raw'), to_source(seg, name='seg'),
+             to_source(fm, name='merges'), to_source(fs, name='splits'))
+
+
+def eval_all_cells():
+    pseg = os.path.join(ROOT_FOLDER, '0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    table_path = os.path.join(ROOT_FOLDER, '0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv')
+    ignore_seg_ids = get_ignore_seg_ids(table_path)
+
+    res = eval_cells(pseg, 't00000/s00/0/cells', pann,
+                     ignore_seg_ids=ignore_seg_ids)
+    print("Eval result:")
+    print(res)
+
+
+def check_nucleus_evaluation():
+    from scripts.segmentation.validation.eval_nuclei import (eval_slice,
+                                                             get_bounding_box)
+
+    praw = os.path.join(ROOT_FOLDER, 'rawdata/sbem-6dpf-1-whole-raw.h5')
+    pseg = os.path.join(ROOT_FOLDER, '0.0.0/segmentations/sbem-6dpf-1-whole-segmented-nuclei-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann:
+        ds_seg = fseg['t00000/s00/0/cells']
+        ds_ann = fann['xy/1000']
+
+        print("Run evaluation ...")
+        res, masks = eval_slice(ds_seg, ds_ann, min_radius=6,
+                                return_masks=True)
+        fm, fs = masks['merges'], masks['splits']
+        print()
+        print("Eval result")
+        print(res)
+        print()
+
+        print("Load raw data ...")
+        bb = get_bounding_box(ds_ann, scale_factor=4.)
+        with h5py.File(praw, 'r') as f:
+            raw = f['t00000/s00/3/cells'][bb].squeeze()
+
+        print("Load seg data ...")
+        seg = ds_seg[bb].squeeze().astype('uint32')
+
+        view(to_source(raw, name='raw'), to_source(seg, name='seg'),
+             to_source(fm, name='merges'), to_source(fs, name='splits'))
+
+
+def eval_all_nuclei():
+    pseg = os.path.join(ROOT_FOLDER, '0.0.0/segmentations/sbem-6dpf-1-whole-segmented-nuclei-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    res = eval_nuclei(pseg, 't00000/s00/0/cells', pann)
+    print("Eval result:")
+    print(res)
+
+
+if __name__ == '__main__':
+    # check_cell_evaluation()
+    eval_all_cells()
+
+    # check_nucleus_evaluation()
+    # eval_all_nulcei()
diff --git a/segmentation/validation/eval_cell_segmentation.py b/segmentation/validation/eval_cell_segmentation.py
new file mode 100644
index 0000000..20e1db5
--- /dev/null
+++ b/segmentation/validation/eval_cell_segmentation.py
@@ -0,0 +1,94 @@
+import argparse
+import os
+import numpy as np
+import h5py
+from scripts.segmentation.validation import eval_cells, get_ignore_seg_ids
+from scripts.attributes.region_attributes import region_attributes
+from scripts.default_config import write_default_global_config
+
+ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5'
+BASELINES = '../../data/rawdata/evaluation/baseline_cell_segmentations.h5'
+
+
+def get_label_ids(path, key):
+    with h5py.File(path, 'r') as f:
+        ds = f[key]
+        max_id = ds.attrs['maxId']
+    label_ids = np.arange(max_id + 1)
+    return label_ids
+
+
+def compute_baseline_tables():
+    names = ['lmc', 'mc', 'curated_lmc', 'curated_mc']
+    path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation',
+                        'baseline_cell_segmentations.h5')
+    table_prefix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation'
+    im_folder = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.0/images'
+    seg_folder = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.0/segmentations'
+    for name in names:
+        key = name
+        out_path = os.path.join(table_prefix, '%s.csv' % name)
+        tmp_folder = './tmp_regions_%s' % name
+        config_folder = os.path.join(tmp_folder, 'configs')
+        write_default_global_config(config_folder)
+        label_ids = get_label_ids(path, key)
+        region_attributes(path, out_path, im_folder, seg_folder,
+                          label_ids, tmp_folder, target='local', max_jobs=64,
+                          key_seg=key)
+
+
+def eval_seg(path, key, table):
+    ignore_ids = get_ignore_seg_ids(table)
+    fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
+                             ignore_seg_ids=ignore_ids)
+    print("Evaluation yields:")
+    print("False merges:", fm)
+    print("False splits:", fs)
+    print("Total number of annotations:", tot)
+
+
+def eval_baselines():
+    path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation',
+                        'baseline_cell_segmentations.h5')
+    names = ['lmc', 'mc', 'curated_lmc', 'curated_mc']
+    table_prefix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation'
+    results = {}
+    for name in names:
+        print("Run evaluation for %s ..." % name)
+        table = os.path.join(table_prefix, '%s.csv' % name)
+        ignore_ids = get_ignore_seg_ids(table)
+        key = name
+        fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
+                                 ignore_seg_ids=ignore_ids)
+        results[name] = (fm, fs, tot)
+
+    for name in names:
+        fm, fs, tot = results[name]
+        print("Evaluation of", name, "yields:")
+        print("False merges:", fm)
+        print("False splits:", fs)
+        print("Total number of annotations:", tot)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", type=str, help="Path to segmentation that should be validated.")
+    parser.add_argument("table", type=str, help="Path to table with region/semantic assignments")
+    parser.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key")
+    parser.add_argument("--baselines", type=int, default=0,
+                        help="Whether to evaluate the baseline segmentations (overrides path)")
+    args = parser.parse_args()
+
+    baselines = bool(args.baselines)
+    if baselines:
+        eval_baselines()
+    else:
+        path = args.path
+        table = args.table
+        key = args.key
+        assert os.path.exists(path), path
+        eval_seg(path, key, table)
+
+
+if __name__ == '__main__':
+    # compute_baseline_tables()
+    main()
diff --git a/segmentation/validation/eval_nucleus_segmentation.py b/segmentation/validation/eval_nucleus_segmentation.py
new file mode 100644
index 0000000..98aa0e1
--- /dev/null
+++ b/segmentation/validation/eval_nucleus_segmentation.py
@@ -0,0 +1,29 @@
+import argparse
+import os
+from scripts.segmentation.validation import eval_nuclei
+
+ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5'
+
+
+def eval_seg(path, key):
+    fp, fn, tot = eval_nuclei(path, key, ANNOTATIONS)
+    print("Evaluation yields:")
+    print("False positives:", fp)
+    print("False negatives:", fn)
+    print("Total number of annotations:", tot)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", type=str, help="Path to nuclei segmentation that should be validated.")
+    parser.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key")
+    args = parser.parse_args()
+
+    path = args.path
+    key = args.key
+    assert os.path.exists(path), path
+    eval_seg(path, key)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/segmentation/validation/proofread_annotations.py b/segmentation/validation/proofread_annotations.py
new file mode 100644
index 0000000..0fd4e1a
--- /dev/null
+++ b/segmentation/validation/proofread_annotations.py
@@ -0,0 +1,29 @@
+import os
+from scripts.segmentation.validation.refine_annotations import refine, export_refined
+from scripts.segmentation.validation.eval_cells import get_ignore_seg_ids
+
+
+def proofread(orientation, slice_id):
+    seg_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.1',
+                            'segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    seg_key = 't00000/s00/0/cells'
+    table_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.1',
+                              'tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv')
+    ignore_seg_ids = get_ignore_seg_ids(table_path)
+
+    proj_folder = 'project_folders/proofread_%s_%i' % (orientation, slice_id)
+    refine(seg_path, seg_key, ignore_seg_ids,
+           orientation, slice_id, proj_folder)
+
+
+# orientation xy has slices:
+# 1000, 2000, 4000, 7000
+# orientation xz has slices:
+# 4000, 5998, 8662, 9328
+if __name__ == '__main__':
+    orientation = 'xz'
+    slice_id = 4000
+    proofread(orientation, slice_id)
+    # done:
+    # xy: 1000, 2000, 4000, 7000
+    # xz: 4000 (partially)
-- 
GitLab