From 0e93379d8c0b62df913a6825eaa9bd5b1835e2b9 Mon Sep 17 00:00:00 2001 From: Constantin Pape <c.pape@gmx.net> Date: Thu, 17 Oct 2019 20:50:07 +0200 Subject: [PATCH] Finish (initial) implementation of segmentation correction and validation. --- analysis/validation/check_validation.py | 67 --- analysis/validation/eval_cell_segmentation.py | 63 --- .../validation/eval_nucleus_segmentation.py | 0 scripts/attributes/region_attributes.py | 5 +- scripts/segmentation/correction/__init__.py | 8 + .../correction/annotation_tool.py | 213 +++++++++ .../correction/assignment_diffs.py | 54 +++ .../correction/cillia_correction_tool.py | 216 +++++++++ .../correction/correct_false_merges.py | 23 - .../correction/correction_tool.py | 441 ++++++++++++++++++ .../correction/export_node_labels.py | 177 +++++++ scripts/segmentation/correction/heuristics.py | 173 +++++++ .../correction/mark_false_merges.py | 20 - scripts/segmentation/correction/preprocess.py | 117 +++++ .../segmentation/validation/eval_nuclei.py | 74 ++- .../validation/evaluate_annotations.py | 4 +- .../validation/refine_annotations.py | 12 +- .../correction/correct_cell_segmentation.py | 175 +++++++ segmentation/correction/deprecated.py | 47 ++ segmentation/correction/fix_assignments.py | 69 +++ segmentation/correction/morphology_outlier.py | 53 +++ segmentation/validation/check_validation.py | 106 +++++ .../validation/eval_cell_segmentation.py | 94 ++++ .../validation/eval_nucleus_segmentation.py | 29 ++ .../validation/proofread_annotations.py | 29 ++ 25 files changed, 2082 insertions(+), 187 deletions(-) delete mode 100644 analysis/validation/check_validation.py delete mode 100644 analysis/validation/eval_cell_segmentation.py delete mode 100644 analysis/validation/eval_nucleus_segmentation.py create mode 100644 scripts/segmentation/correction/annotation_tool.py create mode 100644 scripts/segmentation/correction/assignment_diffs.py create mode 100644 scripts/segmentation/correction/cillia_correction_tool.py delete mode 100644 scripts/segmentation/correction/correct_false_merges.py create mode 100644 scripts/segmentation/correction/correction_tool.py create mode 100644 scripts/segmentation/correction/export_node_labels.py create mode 100644 scripts/segmentation/correction/heuristics.py delete mode 100644 scripts/segmentation/correction/mark_false_merges.py create mode 100644 scripts/segmentation/correction/preprocess.py create mode 100644 segmentation/correction/correct_cell_segmentation.py create mode 100644 segmentation/correction/deprecated.py create mode 100644 segmentation/correction/fix_assignments.py create mode 100644 segmentation/correction/morphology_outlier.py create mode 100644 segmentation/validation/check_validation.py create mode 100644 segmentation/validation/eval_cell_segmentation.py create mode 100644 segmentation/validation/eval_nucleus_segmentation.py create mode 100644 segmentation/validation/proofread_annotations.py diff --git a/analysis/validation/check_validation.py b/analysis/validation/check_validation.py deleted file mode 100644 index 3007aa6..0000000 --- a/analysis/validation/check_validation.py +++ /dev/null @@ -1,67 +0,0 @@ -import h5py -from heimdall import view, to_source -from scripts.segmentation.validation import eval_cells, eval_nuclei, get_ignore_seg_ids - - -def check_cell_evaluation(): - from scripts.segmentation.validation.eval_cells import (eval_slice, - get_bounding_box) - - praw = '../../data/rawdata/sbem-6dpf-1-whole-raw.h5' - pseg = 
'../../data/0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5' - pann = '../../data/rawdata/evaluation/validation_annotations.h5' - - table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv' - ignore_seg_ids = get_ignore_seg_ids(table_path) - - with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann: - ds_seg = fseg['t00000/s00/0/cells'] - ds_ann = fann['xy/1000'] - - print("Run evaluation ...") - res, masks = eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius=16, - return_masks=True) - fm, fs = masks['merges'], masks['splits'] - print() - print("Eval result") - print(res) - print() - - print("Load raw data ...") - bb = get_bounding_box(ds_ann) - with h5py.File(praw, 'r') as f: - raw = f['t00000/s00/1/cells'][bb].squeeze() - - print("Load seg data ...") - seg = ds_seg[bb].squeeze().astype('uint32') - - view(to_source(raw, name='raw'), to_source(seg, name='seg'), - to_source(fm, name='merges'), to_source(fs, name='splits')) - - -def eval_all_cells(): - pseg = '../../data/0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5' - pann = '../../data/rawdata/evaluation/validation_annotations.h5' - - table_path = '../../data/0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv' - ignore_seg_ids = get_ignore_seg_ids(table_path) - - res = eval_cells(pseg, 't00000/s00/0/cells', pann, - ignore_seg_ids=ignore_seg_ids) - print("Eval result:") - print(res) - - -# TODO -def check_nucleus_evaluation(): - pass - - -# TODO -def eval_all_nulcei(): - eval_nuclei() - - -if __name__ == '__main__': - # check_cell_evaluation() - eval_all_cells() diff --git a/analysis/validation/eval_cell_segmentation.py b/analysis/validation/eval_cell_segmentation.py deleted file mode 100644 index b8f6470..0000000 --- a/analysis/validation/eval_cell_segmentation.py +++ /dev/null @@ -1,63 +0,0 @@ -# TODO shebang -import argparse -import os -from scripts.segmentation.validation import eval_cells, get_ignore_seg_ids - -ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5' -BASELINES = '../../data/rawdata/evaluation/baseline_cell_segmentations.h5' - - -def eval_seg(path, key, table): - ignore_ids = get_ignore_seg_ids(table) - fm, fs, tot = eval_cells(path, key, ANNOTATIONS, - ignore_seg_ids=ignore_ids) - print("Evaluation yields:") - print("False merges:", fm) - print("False splits:", fs) - print("Total number of annotations:", tot) - - -def eval_baselines(): - names = ['lmc', 'mc', 'curated_lmc', 'curated_mc'] - # TODO still need to compute region tables for the baselines - tables = ['', - '', - '', - ''] - results = {} - for name, table in zip(names, tables): - ignore_ids = get_ignore_seg_ids(table) - fm, fs, tot = eval_cells(path, key, ANNOTATIONS, - ignore_seg_ids=ignore_ids) - results[name] = (fm, fs, tot) - - for name in names: - print("Evaluation of", name, "yields:") - print("False merges:", fm) - print("False splits:", fs) - print("Total number of annotations:", tot) - - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("path", type=str, help="Path to segmentation that should be validated.") - parser.add_argument("table", type=str, help="Path to table with region/semantic assignments") - parse.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key") - parser.add_argument("--baselines", type=int, default=0, - help="Whether to evaluate the baseline segmentations (overrides path)") - args = parser.parse_args() - - baselines = bool(args.baselines) - if baselines: - 
eval_baselines() - else: - path = args.path - table = args.table - key = args.key - assert os.path.exists(path), path - eval_seg(path, key, table) - - -if __name__ == '__main__': - main() diff --git a/analysis/validation/eval_nucleus_segmentation.py b/analysis/validation/eval_nucleus_segmentation.py deleted file mode 100644 index e69de29..0000000 diff --git a/scripts/attributes/region_attributes.py b/scripts/attributes/region_attributes.py index 950836b..ae9ea9f 100644 --- a/scripts/attributes/region_attributes.py +++ b/scripts/attributes/region_attributes.py @@ -63,9 +63,8 @@ def muscle_attributes(muscle_path, key_muscle, def region_attributes(seg_path, region_out, image_folder, segmentation_folder, - label_ids, tmp_folder, target, max_jobs): - - key_seg = 't00000/s00/2/cells' + label_ids, tmp_folder, target, max_jobs, + key_seg='t00000/s00/2/cells'): key_tissue = 't00000/s00/0/cells' # 1.) compute the mapping to carved regions diff --git a/scripts/segmentation/correction/__init__.py b/scripts/segmentation/correction/__init__.py index e69de29..d83c810 100644 --- a/scripts/segmentation/correction/__init__.py +++ b/scripts/segmentation/correction/__init__.py @@ -0,0 +1,8 @@ +from .annotation_tool import AnnotationTool +from .correction_tool import CorrectionTool +from .cillia_correction_tool import CiliaCorrectionTool +from .marker_tool import MarkerTool + +from .preprocess import preprocess +from .export_node_labels import export_node_labels, to_paintera_format +from .heuristics import rank_false_merges, get_ignore_ids diff --git a/scripts/segmentation/correction/annotation_tool.py b/scripts/segmentation/correction/annotation_tool.py new file mode 100644 index 0000000..a0cfa3b --- /dev/null +++ b/scripts/segmentation/correction/annotation_tool.py @@ -0,0 +1,213 @@ +import os +import json +import queue +import sys +import threading + +import numpy as np +import napari + +from heimdall import view, to_source +from elf.io import open_file + + +# Annotate segments +class AnnotationTool: + n_threads = 3 + queue_len = 6 + + def __init__(self, project_folder, id_path, + table_path, table_key, scale_factor, + raw_path, raw_key, + ws_path, ws_key, + node_label_path=None, + node_label_key=None): + # + self.project_folder = project_folder + os.makedirs(self.project_folder, exist_ok=True) + self.processed_ids_file = os.path.join(project_folder, 'processed_ids.json') + self.annotation_path = os.path.join(project_folder, 'annotations.json') + + self.id_path = id_path + + self.table_path = table_path + self.table_key = table_key + self.scale_factor = scale_factor + + self.raw_path = raw_path + self.raw_key = raw_key + + self.ws_path = ws_path + self.ws_key = ws_key + + self.node_label_path = node_label_path + self.node_label_key = node_label_key + + self.load_all_data() + self.init_queue_and_workers() + print("Initialization done") + + def load_all_data(self): + self.ds_raw = open_file(self.raw_path)[self.raw_key] + self.ds_ws = open_file(self.ws_path)[self.ws_key] + self.shape = self.ds_raw.shape + assert self.ds_ws.shape == self.shape + + if self.node_label_path is None: + self.node_labels = None + else: + with open_file(self.node_label_path, 'r') as f: + self.node_labels = f[self.node_label_key][:] + + with open(self.id_path) as f: + self.ids = np.array(json.load(f)) + + if os.path.exists(self.processed_ids_file): + with open(self.processed_ids_file) as f: + self.processed_ids = json.load(f) + else: + self.processed_ids = [] + + already_processed = np.in1d(self.ids, self.processed_ids) + missing_ids = 
self.ids[~already_processed] + + if os.path.exists(self.annotation_path): + with open(self.annotation_path) as f: + self.annotations = json.load(f) + else: + self.annotations = {} + + self.next_queue = queue.Queue() + for mi in missing_ids: + self.next_queue.put_nowait(mi) + + # morphology table entries + # id (1) + # size (1) + # com (3) + # bb-min (3) + # bb-max (3) + with open_file(self.table_path, 'r') as f: + table = f[self.table_key][:] + self.bb_starts = table[:, 5:8] + self.bb_stops = table[:, 8:] + self.bb_starts /= self.scale_factor + self.bb_stops /= self.scale_factor + + def load_segment(self, seg_id): + + # get bounding box of this segment + starts, stops = self.bb_starts[seg_id], self.bb_stops[seg_id] + halo = [2, 2, 2] + bb = tuple(slice(max(0, int(sta - ha)), + min(sh, int(sto + ha))) + for sta, sto, ha, sh in zip(starts, stops, halo, self.shape)) + + # load raw and watershed + raw = self.ds_raw[bb] + ws = self.ds_ws[bb] + + # make the segment mask + if self.node_labels is None: + seg_mask = (ws == seg_id).astype('uint32') + ws = None + else: + node_ids = np.where(self.node_labels == seg_id)[0].astype('uint64') + seg_mask = np.isin(ws, node_ids) + ws[~seg_mask] = 0 + + return raw, ws, seg_mask.astype('uint32') + + def worker_thread(self): + while not self.next_queue.empty(): + seg_id = self.next_queue.get() + print("Loading seg", seg_id) + qitem = self.load_segment(seg_id) + self.queue.put((seg_id, qitem)) + + def init_queue_and_workers(self): + self.queue = queue.Queue(maxsize=self.queue_len) + for i in range(self.n_threads): + t = threading.Thread(name='worker-%i' % i, target=self.worker_thread) + t.setDaemon(True) + t.start() + + def correct_segment(self, seg_id, qitem): + print("Processing segment:", seg_id) + raw, ws, seg = qitem + annotation = None + + with napari.gui_qt(): + if ws is None: + viewer = view(to_source(raw, name='raw'), + to_source(seg, name='seg'), + return_viewer=True) + else: + viewer = view(to_source(raw, name='raw'), + to_source(ws, name='ws'), + to_source(seg, name='seg'), + return_viewer=True) + + @viewer.bind_key('h') + def print_help(viewer): + print("The following annotations are available:") + print("[r] - needs to be revisited") + print("[i] - incomplete, needs to be merged with other id") + print("[m] - merge with background") + print("[y] - confirm segment") + print("[c] - enter custom annotation") + print("Other options:") + print("[q] - quit") + + @viewer.bind_key('r') + def revisit(viewer): + nonlocal annotation + print("Setting annotation to revisit") + annotation = 'revisit' + + @viewer.bind_key('m') + def merge(viewer): + nonlocal annotation + print("Setting annotation to merge") + annotation = 'merge' + + @viewer.bind_key('i') + def incomplete(viewer): + nonlocal annotation + print("Setting annotation to incomplete") + annotation = 'incomplete' + + @viewer.bind_key('y') + def confirm(viewer): + nonlocal annotation + print("Setting annotation to confirm") + annotation = 'confirm' + + @viewer.bind_key('c') + def custom(viewer): + nonlocal annotation + annotation = input("Enter custom annotation") + + @viewer.bind_key('q') + def quit(viewer): + self.save_annotation(seg_id, annotation) + sys.exit(0) + + self.save_annotation(seg_id, annotation) + + def save_annotation(self, seg_id, annotation): + self.processed_ids.append(int(seg_id)) + with open(self.processed_ids_file, 'w') as f: + json.dump(self.processed_ids, f) + if annotation is None: + return + self.annotations[int(seg_id)] = annotation + with open(self.annotation_path, 'w') as f: 
+            json.dump(self.annotations, f)
+
+    def __call__(self):
+        left_to_process = len(self.ids) - len(self.processed_ids)
+        while left_to_process > 0:
+            seg_id, qitem = self.queue.get()
+            self.correct_segment(seg_id, qitem)
+            left_to_process = len(self.ids) - len(self.processed_ids)
diff --git a/scripts/segmentation/correction/assignment_diffs.py b/scripts/segmentation/correction/assignment_diffs.py
new file mode 100644
index 0000000..c85b046
--- /dev/null
+++ b/scripts/segmentation/correction/assignment_diffs.py
@@ -0,0 +1,54 @@
+import json
+import os
+
+import numpy as np
+import luigi
+import nifty.ground_truth as ngt
+from cluster_tools.node_labels import NodeLabelWorkflow
+
+# TODO some of this is general purpose and should go to elf
+
+
+def node_labels(ws_path, ws_key, input_path, input_key,
+                output_path, output_key, prefix,
+                tmp_folder, target='slurm', max_jobs=250):
+    task = NodeLabelWorkflow
+
+    configs = task.get_config()
+    config_folder = os.path.join(tmp_folder, 'configs')
+    os.makedirs(config_folder, exist_ok=True)
+
+    conf = configs['global']
+    shebang = '/g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python3.7'
+    conf['shebang'] = shebang
+    conf['block_shape'] = [16, 256, 256]
+
+    with open(os.path.join(config_folder, 'global.config'), 'w') as f:
+        json.dump(conf, f)
+
+    conf = configs['block_node_labels']
+    conf.update({'mem_limit': 4, 'time_limit': 120})
+    with open(os.path.join(config_folder, 'block_node_labels.config'), 'w') as f:
+        json.dump(conf, f)
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_folder,
+             max_jobs=max_jobs, target=target,
+             ws_path=ws_path, ws_key=ws_key,
+             input_path=input_path, input_key=input_key,
+             output_path=output_path, output_key=output_key,
+             prefix=prefix)
+    luigi.build([t], local_scheduler=True)
+
+
+# we only look at objects in the reference assignment that are SPLIT in
+# the new assignment. 
For the other direction, just reverse reference and new +def assignment_diff_splits(reference_assignment, new_assignment): + assert reference_assignment.shape == new_assignment.shape + reference_ids = np.unique(reference_assignment) + + print("Computing assignment diff ...") + ovlp_comp = ngt.overlap(reference_assignment, new_assignment) + split_ids = [ovlp_comp.overlapArrays(ref_id)[0] for ref_id in reference_ids] + split_ids = {int(ref_id): len(ovlps) for ref_id, ovlps in zip(reference_ids, split_ids) + if len(ovlps) > 1} + return split_ids diff --git a/scripts/segmentation/correction/cillia_correction_tool.py b/scripts/segmentation/correction/cillia_correction_tool.py new file mode 100644 index 0000000..7c0bd7c --- /dev/null +++ b/scripts/segmentation/correction/cillia_correction_tool.py @@ -0,0 +1,216 @@ +# initially: +# go over all cilia, load the volume and highlight the cell the cilium was mapped to if applicable + +import json +import os +import queue +import sys +import threading + +import numpy as np +import pandas as pd +import napari + +from heimdall import view, to_source +from elf.io import open_file +from scripts.files.xml_utils import get_h5_path_from_xml + + +def xml_to_h5_path(xml_path): + path = get_h5_path_from_xml(xml_path, return_absolute_path=True) + return path + + +class CiliaCorrectionTool: + n_threads = 1 + queue_len = 2 + + def __init__(self, project_folder, + version_folder, scale, cilia_cell_table): + self.project_folder = project_folder + os.makedirs(self.project_folder, exist_ok=True) + self.processed_id_file = os.path.join(self.project_folder, 'processed_ids.json') + + assert os.path.exists(version_folder) + + raw_path = os.path.join(version_folder, 'images', 'sbem-6dpf-1-whole-raw.xml') + self.raw_path = xml_to_h5_path(raw_path) + + cilia_seg_path = os.path.join(version_folder, 'segmentations', + 'sbem-6dpf-1-whole-segmented-cilia-labels.xml') + self.cilia_seg_path = xml_to_h5_path(cilia_seg_path) + + cell_seg_path = os.path.join(version_folder, 'segmentations', + 'sbem-6dpf-1-whole-segmented-cells-labels.xml') + self.cell_seg_path = xml_to_h5_path(cell_seg_path) + + self.cilia_table_path = os.path.join(version_folder, 'tables', + 'sbem-6dpf-1-whole-segmented-cilia-labels', 'default.csv') + self.cilia_cell_table = pd.read_csv(cilia_cell_table, sep='\t') + + self.scale = scale + self.init_data() + self.init_queue_and_workers() + + def init_data(self): + # init the bounding boxes + table = pd.read_csv(self.cilia_table_path, sep='\t') + bb_start = table[['bb_min_z', 'bb_min_y', 'bb_min_x']].values.astype('float32') + bb_start[np.isinf(bb_start)] = 0 + bb_stop = table[['bb_max_z', 'bb_max_y', 'bb_max_x']].values.astype('float32') + + resolution = [.025, .01, .01] + scale_factor = [2 ** max(0, self.scale - 1)] + 2 * [2 ** self.scale] + resolution = [res * sf for res, sf in zip(resolution, scale_factor)] + + halo = [4, 32, 32] + self.bbs = [tuple(slice(int(sta / res - ha), int(sto / res + ha)) + for sta, sto, res, ha in zip(start, stop, resolution, halo)) + for start, stop in zip(bb_start, bb_stop)] + + # mapping from cilia ids to seg_ids + cil_map_ids = self.cilia_cell_table['cilia_id'].values + cell_map_ids = self.cilia_cell_table['cell_id'].values + self.id_mapping = {cil_id: cell_id for cil_id, cell_id in zip(cil_map_ids, cell_map_ids)} + + # init the relevant ids + self.cilia_ids = np.arange(len(self.bbs)) + if os.path.exists(self.processed_id_file): + with open(self.processed_id_file) as f: + self.processed_id_map = json.load(f) + else: + 
self.processed_id_map = {} + self.processed_id_map = {int(k): v for k, v in self.processed_id_map.items()} + self.processed_ids = list(self.processed_id_map.keys()) + + already_processed = np.in1d(self.cilia_ids, self.processed_ids) + missing_ids = self.cilia_ids[~already_processed] + + # fill the queue + self.next_queue = queue.Queue() + for mi in missing_ids: + # if mi in (0, 1): + # continue + self.next_queue.put_nowait(mi) + + def worker_thread(self): + while not self.next_queue.empty(): + seg_id = self.next_queue.get() + print("Loading seg", seg_id) + qitem = self.load_data(seg_id) + self.queue.put((seg_id, qitem)) + + def init_queue_and_workers(self): + self.queue = queue.Queue(maxsize=self.queue_len) + for i in range(self.n_threads): + t = threading.Thread(name='worker-%i' % i, target=self.worker_thread) + # t.setDaemon(True) + t.start() + save_folder = os.path.join(self.project_folder, 'results') + os.makedirs(save_folder, exist_ok=True) + + def load_data(self, cil_id): + if cil_id in (0, 1): + return None + cell_seg_key = 't00000/s00/%i/cells' % (self.scale - 1,) + cil_seg_key = 't00000/s00/%i/cells' % (self.scale + 1,) + raw_key = 't00000/s00/%i/cells' % self.scale + + bb = self.bbs[cil_id] + with open_file(self.raw_path, 'r') as f: + ds = f[raw_key] + raw = ds[bb] + + with open_file(self.cilia_seg_path, 'r') as f: + ds = f[cil_seg_key] + cil_seg = ds[bb].astype('uint32') + cil_mask = cil_seg == cil_id + cil_mask = 2 * cil_mask.astype('uint32') + + cell_id = self.id_mapping[cil_id] + if cell_id in (0, np.nan): + cell_seg = None + else: + + with open_file(self.cell_seg_path, 'r') as f: + ds = f[cell_seg_key] + cell_seg = ds[bb].astype('uint32') + cell_seg = (cell_seg == cell_id).astype('uint32') + + return raw, cil_seg, cil_mask, cell_seg + + def __call__(self): + left_to_process = len(self.cilia_ids) - len(self.processed_ids) + print("Left to process:", left_to_process) + while left_to_process > 0: + seg_id, qitem = self.queue.get() + self.correct_segment(seg_id, qitem) + left_to_process = len(self.cilia_ids) - len(self.processed_ids) + + def correct_segment(self, seg_id, qitem): + if qitem is None: + return + + print("Processing cilia:", seg_id) + raw, cil_seg, cil_mask, cell_seg = qitem + + with napari.gui_qt(): + if cell_seg is None: + viewer = view(to_source(raw, name='raw'), to_source(cil_seg, name='cilia-segmentation'), + to_source(cil_mask, name='cilia-mask'), return_viewer=True) + else: + viewer = view(to_source(raw, name='raw'), to_source(cil_seg, name='cilia-segmentation'), + to_source(cil_mask, name='cilia-mask'), to_source(cell_seg, name='cell-segmentation'), + return_viewer=True) + + @viewer.bind_key('c') + def confirm(viewer): + print("Confirming the current id", seg_id, "as correct") + self.processed_id_map[int(seg_id)] = 'correct' + + @viewer.bind_key('b') + def background(viewer): + print("Confirming the current id", seg_id, "into background") + self.processed_id_map[int(seg_id)] = 'background' + + @viewer.bind_key('m') + def merge(viewer): + print("Merging the current id", seg_id, "with other cilia") + valid_input = False + while not valid_input: + merge_id = input("Please enter the merge id:") + try: + merge_id = int(merge_id) + valid_input = True + except ValueError: + valid_input = False + print("You have entered an invalid input", merge_id, "please try again") + self.processed_id_map[int(seg_id)] = merge_id + + @viewer.bind_key('r') + def revisit(viewer): + print("Marking the current id", seg_id, "to be revisited because something is off") + 
self.processed_id_map[int(seg_id)] = 'revisit' + + @viewer.bind_key('h') + def print_help(viewer): + print("[c] - confirm cilia as correct") + print("[b] - mark cilia as background") + print("[m] - merge cilia with other cilia id") + print("[d] - revisit this cilia") + print("[q] - quit") + + # save progress and sys.exit + @viewer.bind_key('q') + def quit(viewer): + print("Quit correction tool") + self.save_result(seg_id) + sys.exit(0) + + # save the results for this segment + self.save_result(seg_id) + + def save_result(self, seg_id): + self.processed_ids.append(seg_id) + with open(self.processed_id_file, 'w') as f: + json.dump(self.processed_id_map, f) diff --git a/scripts/segmentation/correction/correct_false_merges.py b/scripts/segmentation/correction/correct_false_merges.py deleted file mode 100644 index 59e9e35..0000000 --- a/scripts/segmentation/correction/correct_false_merges.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -import napari -from elf.io import open_file - - -# TODO -# correct false merges via lifted multicut (or watershed) -# load raw data, watershed from bounding box for correction id -# load sub-graph for nodes corresponding to this segment -# (this takes long, so preload 5 or so via daemon process) -# add layer for seeds, resolve the segment via lmc (or watershed) -# once happy, store the new ids and move on to the next -class CorrectFalseMerges: - def __init__(self, project_folder, - correct_id_path=None, table_path=None, - raw_path=None, raw_key=None, - ws_path=None, ws_key=None, - node_label_path=None, node_label_key=None, - problem_path=None, graph_key=None, cost_key=None): - pass - - def __call__(self): - pass diff --git a/scripts/segmentation/correction/correction_tool.py b/scripts/segmentation/correction/correction_tool.py new file mode 100644 index 0000000..9f36a84 --- /dev/null +++ b/scripts/segmentation/correction/correction_tool.py @@ -0,0 +1,441 @@ +import os +import json +import queue +import sys +import threading + +import numpy as np +import napari +import nifty +import nifty.distributed as ndist +import nifty.tools as nt +import vigra + +from heimdall import view, to_source +from elf.io import open_file + + +# correct false merges via lifted multicut (or watershed) +# load raw data, watershed from bounding box for correction id +# load sub-graph for nodes corresponding to this segment +# (this takes long, so preload 5 or so via daemon process) +# add layer for seeds, resolve the segment via lmc (or watershed) +# once happy, store the new ids and move on to the next +class CorrectionTool: + n_threads = 3 + queue_len = 3 + + def __init__(self, project_folder, + false_merge_id_path=None, + table_path=None, table_key=None, scale_factor=None, + raw_path=None, raw_key=None, + seg_path=None, seg_key=None, + ws_path=None, ws_key=None, + node_label_path=None, node_label_key=None, + problem_path=None, graph_key=None, feat_key=None, + load_lazy=False): + # + self.project_folder = project_folder + self.config_file = os.path.join(project_folder, 'correct_false_merges_config.json') + self.processed_ids_file = os.path.join(project_folder, 'processed_ids.json') + self.annotation_path = os.path.join(project_folder, 'annotations.json') + + if os.path.exists(self.config_file): + self.read_config() + serialize_config = False + else: + assert false_merge_id_path is not None + self.false_merge_id_path = false_merge_id_path + + assert table_path is not None + self.table_path = table_path + assert table_key is not None + self.table_key = table_key + assert scale_factor is not 
None + self.scale_factor = scale_factor + + assert raw_path is not None + self.raw_path = raw_path + assert raw_key is not None + self.raw_key = raw_key + + assert seg_path is not None + self.seg_path = seg_path + assert seg_key is not None + self.seg_key = seg_key + + assert ws_path is not None + self.ws_path = ws_path + assert ws_key is not None + self.ws_key = ws_key + + assert node_label_path is not None + self.node_label_path = node_label_path + assert node_label_key is not None + self.node_label_key = node_label_key + + assert problem_path is not None + self.problem_path = problem_path + assert graph_key is not None + self.graph_key = graph_key + assert feat_key is not None + self.feat_key = feat_key + + serialize_config = True + + # TODO implement lazy loading + self.load_lazy = load_lazy + print("Loading graph, weights and node labels ...") + self.load_all_data() + print("... done") + + print("Initializing queues ...") + self.init_queue_and_workers() + print("... done") + + if serialize_config: + self.write_config() + print("Initialization done") + + def read_config(self): + with open(self.config_file) as f: + conf = json.load(f) + + self.false_merge_id_path = conf['false_merge_id_path'] + self.table_path = conf['table_path'] + self.table_key = conf['table_key'] + self.scale_factor = conf['scale_factor'] + + self.raw_path, self.raw_key = conf['raw_path'], conf['raw_key'] + self.seg_path, self.seg_key = conf['seg_path'], conf['seg_key'] + self.ws_path, self.ws_key = conf['ws_path'], conf['ws_key'] + self.node_label_path, self.node_label_key = conf['node_label_path'], conf['node_label_key'] + + self.problem_path = conf['problem_path'] + self.graph_key = conf['graph_key'] + self.feat_key = conf['feat_key'] + + def write_config(self): + os.makedirs(self.project_folder, exist_ok=True) + conf = {'false_merge_id_path': self.false_merge_id_path, + 'table_path': self.table_path, 'table_key': self.table_key, + 'scale_factor': self.scale_factor, + 'raw_path': self.raw_path, 'raw_key': self.raw_key, + 'seg_path': self.seg_path, 'seg_key': self.seg_key, + 'ws_path': self.ws_path, 'ws_key': self.ws_key, + 'node_label_path': self.node_label_path, 'node_label_key': self.node_label_key, + 'problem_path': self.problem_path, 'graph_key': self.graph_key, + 'feat_key': self.feat_key} + with open(self.config_file, 'w') as f: + json.dump(conf, f) + + def load_all_data(self): + self.ds_raw = open_file(self.raw_path)[self.raw_key] + self.ds_seg = open_file(self.seg_path)[self.seg_key] + self.ds_ws = open_file(self.ws_path)[self.ws_key] + self.shape = self.ds_raw.shape + assert self.ds_ws.shape == self.shape + + with open_file(self.node_label_path, 'r') as f: + self.node_labels = f[self.node_label_key][:] + + with open(self.false_merge_id_path) as f: + self.false_merge_ids = np.array(json.load(f)) + + if os.path.exists(self.processed_ids_file): + with open(self.processed_ids_file) as f: + self.processed_ids = json.load(f) + else: + self.processed_ids = [] + + already_processed = np.in1d(self.false_merge_ids, self.processed_ids) + missing_ids = self.false_merge_ids[~already_processed] + + if os.path.exists(self.annotation_path): + with open(self.annotation_path) as f: + self.annotations = json.load(f) + else: + self.annotations = {} + + self.next_queue = queue.Queue() + for mi in missing_ids: + self.next_queue.put_nowait(mi) + + self.graph = ndist.Graph(self.problem_path, self.graph_key, self.n_threads) + self.uv_ids = self.graph.uvIds() + with open_file(self.problem_path, 'r') as f: + ds = f[self.feat_key] 
+ ds.n_threads = self.n_threads + + self.probs = ds[:, 0] + + # morphology table entries + # id (1) + # size (1) + # com (3) + # bb-min (3) + # bb-max (3) + with open_file(self.table_path, 'r') as f: + table = f[self.table_key][:] + self.bb_starts = table[:, 5:8] + self.bb_stops = table[:, 8:] + self.bb_starts /= self.scale_factor + self.bb_stops /= self.scale_factor + + def load_subgraph(self, node_ids): + inner_edges, _ = self.graph.extractSubgraphFromNodes(node_ids, allowInvalidNodes=True) + nodes_relabeled, max_id, mapping = vigra.analysis.relabelConsecutive(node_ids, + start_label=0, + keep_zeros=False) + uv_ids = self.uv_ids[inner_edges] + uv_ids = nt.takeDict(mapping, uv_ids) + n_nodes = max_id + 1 + graph = nifty.graph.undirectedGraph(n_nodes) + graph.insertEdges(uv_ids) + + probs = self.probs[inner_edges] + assert len(probs) == graph.numberOfEdges + return graph, probs, mapping + + def load_segment(self, seg_id): + + # get bounding box of this segment + starts, stops = self.bb_starts[seg_id], self.bb_stops[seg_id] + halo = [2, 2, 2] + bb = tuple(slice(max(0, int(sta - ha)), + min(sh, int(sto + ha))) + for sta, sto, ha, sh in zip(starts, stops, halo, self.shape)) + + # extract the sub-graph + node_ids = np.where(self.node_labels == seg_id)[0].astype('uint64') + graph, probs, mapping = self.load_subgraph(node_ids) + # graph, probs, mapping = None, None, None + + # load raw and watershed + raw = self.ds_raw[bb] + try: + seg = self.ds_seg[bb] + except RuntimeError: + seg = None + ws = self.ds_ws[bb] + + # make the segment mask + ws[~np.isin(ws, node_ids)] = 0 + if seg is None: + seg_mask = ws > 0 + else: + seg_mask = seg == seg_id + + return raw, ws, seg_mask.astype('uint32'), graph, probs, mapping + + def worker_thread(self): + while not self.next_queue.empty(): + seg_id = self.next_queue.get() + print("Loading seg", seg_id) + qitem = self.load_segment(seg_id) + self.queue.put((seg_id, qitem)) + + def init_queue_and_workers(self): + self.queue = queue.Queue(maxsize=self.queue_len) + for i in range(self.n_threads): + t = threading.Thread(name='worker-%i' % i, target=self.worker_thread) + t.setDaemon(True) + t.start() + save_folder = os.path.join(self.project_folder, 'results') + os.makedirs(save_folder, exist_ok=True) + + @staticmethod + def graph_watershed(graph, probs, ws, seed_points, mapping): + seed_ids = np.unique(seed_points)[1:] + if len(seed_ids) == 0: + return None + + seeds = np.zeros(graph.numberOfNodes, dtype='uint64') + for seed_id in seed_ids: + mask = seed_points == seed_id + seed_nodes = np.unique(ws[mask]) + if seed_nodes[0] == 0: + seed_nodes = seed_nodes[1:] + seed_nodes = nt.takeDict(mapping, seed_nodes) + seeds[seed_nodes] = seed_id + + node_labels = nifty.graph.edgeWeightedWatershedsSegmentation(graph, seeds, probs) + return node_labels + + def correct_segment(self, seg_id, qitem): + print("Processing segment:", seg_id) + raw, ws, seg, graph, probs, mapping = qitem + ws_ids = np.unique(ws)[1:] + node_labels = np.ones(graph.numberOfNodes, dtype='uint64') + + seeds = np.zeros_like(ws, dtype='uint32') + skip_this_segment = False + + with napari.gui_qt(): + viewer = view(to_source(raw, name='raw'), to_source(ws, name='ws'), + to_source(seg, name='seg'), to_source(seeds, name='seeds'), + return_viewer=True) + + @viewer.bind_key('h') + def print_help(viewer): + print("[w] - run watershed from seeds") + print("[s] - skip the current segment (if it's not a merge)") + print("[c] - clear seeds") + print("[d] - save current data for debugging") + print("[a] - annotate") + 
print("[q] - quit") + + @viewer.bind_key('a') + def annotatate(viewer): + msg = """Annotating seg id, choose one of [c] (correct) [s] (revisit at lower scale), [m] (merge to background). + Otherwise, you can add a custom annotation.""" + inp = input(msg) + if inp == 'c': + annotation = 'correct' + elif inp == 's': + annotation = 'revisit' + elif inp == 'm': + annotation = 'merge' + else: + annotation = input("Enter custom annotation.") + self.annotations[int(seg_id)] = annotation + with open(self.annotation_path, 'w') as f: + json.dump(self.annotations, f) + + @viewer.bind_key('s') + def skip(viewer): + print("Skipping the current segment") + nonlocal skip_this_segment + skip_this_segment = True + # TODO quit viewer + + @viewer.bind_key('d') + def debug(viewer): + print("Saving debug data ...") + debug_folder = os.path.join(self.project_folder, 'debug') + os.makedirs(debug_folder, exist_ok=True) + + layers = viewer.layers + seed_points = layers['seeds'].data + with open_file(os.path.join(debug_folder, 'data.n5')) as f: + f.create_dataset('raw', data=raw, compression='gzip') + f.create_dataset('ws', data=ws, compression='gzip') + f.create_dataset('seeds', data=seed_points, compression='gzip') + + with open(os.path.join(debug_folder, 'mapping.json'), 'w') as f: + json.dump(mapping, f) + np.save(os.path.join(debug_folder, 'graph.npy'), graph.uvIds()) + np.save(os.path.join(debug_folder, 'probs.npy'), probs) + print("... done") + + @viewer.bind_key('w') + def watershed(viewer): + nonlocal node_labels, seeds + + print("Run watershed from seed layer ...") + layers = viewer.layers + ws = layers['ws'].data + seeds = layers['seeds'].data + + new_node_labels = self.graph_watershed(graph, probs, ws, seeds, mapping) + if new_node_labels is None: + print("Did not find any seeds, doing nothing") + return + else: + node_labels = new_node_labels + + label_dict = {wsid: node_labels[mapping[wsid]] for wsid in ws_ids} + label_dict[0] = 0 + seg = nt.takeDict(label_dict, ws) + + layers['seg'].data = seg + print("... done") + + @viewer.bind_key('c') + def clear(viewer): + nonlocal node_labels, seeds + print("Clear seeds ...") + confirm = input("Do you really want to clean the seeds? y / [n]") + if confirm != 'y': + return + node_labels = np.ones(graph.numberOfNodes, dtype='uint64') + seeds = np.zeros_like(ws, dtype='uint32') + viewer.layers['seeds'].data = seeds + seg = (ws > 0).astype('uint32') + viewer.layers['seg'].data = seg + print("... 
done") + + # save progress and sys.exit + @viewer.bind_key('q') + def quit(viewer): + print("Quit correction tool") + self.save_segment_result(seg_id, ws, seeds, + node_labels, mapping, skip_this_segment) + sys.exit(0) + + # save the results for this segment + self.save_segment_result(seg_id, ws, seeds, node_labels, mapping, skip_this_segment) + + def save_segment_result(self, seg_id, ws, seeds, node_labels, mapping, skip): + if not skip: + save_file = os.path.join(self.project_folder, 'results', '%i.npz' % seg_id) + node_ids = list(mapping.keys()) + save_labels = [node_labels[mapping[nid]] for nid in node_ids] + + seed_ids = np.unique(seeds[1:]) + seeded_ids = [] + seed_labels = [] + for seed_id in seed_ids: + mask = seeds == seed_id + this_ids = np.unique(ws[mask]) + if this_ids[0] == 0: + this_ids = this_ids[1:] + seeded_ids.extend(this_ids) + seed_labels.extend(len(this_ids) * [seed_id]) + + np.savez_compressed(save_file, node_ids=node_ids, node_labels=save_labels, + seeded_ids=seeded_ids, seed_labels=seed_labels) + self.processed_ids.append(int(seg_id)) + with open(self.processed_ids_file, 'w') as f: + json.dump(self.processed_ids, f) + + def __call__(self): + left_to_process = len(self.false_merge_ids) - len(self.processed_ids) + while left_to_process > 0: + seg_id, qitem = self.queue.get() + self.correct_segment(seg_id, qitem) + left_to_process = len(self.false_merge_ids) - len(self.processed_ids) + + @staticmethod + def debug(debug_folder, n_threads=8): + with open_file(os.path.join(debug_folder, 'data.n5')) as f: + ds = f['raw'] + ds.n_threads = n_threads + raw = ds[:] + + ds = f['ws'] + ds.n_threads = n_threads + ws = ds[:] + + ds = f['seeds'] + ds.n_threads = n_threads + seed_points = ds[:] + + with open(os.path.join(debug_folder, 'mapping.json'), 'r') as f: + mapping = json.load(f) + mapping = {int(k): v for k, v in mapping.items()} + + uv_ids = np.load(os.path.join(debug_folder, 'graph.npy')) + n_nodes = int(uv_ids.max()) + 1 + graph = nifty.graph.undirectedGraph(n_nodes) + graph.insertEdges(uv_ids) + probs = np.load(os.path.join(debug_folder, 'probs.npy')) + + node_labels = CorrectionTool.graph_watershed(graph, probs, ws, seed_points, mapping) + + ws_ids = np.unique(ws)[1:] + label_dict = {wsid: node_labels[mapping[wsid]] for wsid in ws_ids} + label_dict[0] = 0 + seg = nt.takeDict(label_dict, ws) + + view(raw, ws, seg) diff --git a/scripts/segmentation/correction/export_node_labels.py b/scripts/segmentation/correction/export_node_labels.py new file mode 100644 index 0000000..7f3b4c3 --- /dev/null +++ b/scripts/segmentation/correction/export_node_labels.py @@ -0,0 +1,177 @@ +import os +from math import ceil, floor + +import numpy as np +import vigra +from elf.io import open_file + + +def export_node_labels(path, in_key, out_key, project_folder): + with open_file(path, 'r') as f: + ds = f[in_key] + node_labels = ds[:] + + result_folder = os.path.join(project_folder, 'results') + res_files = os.listdir(result_folder) + + print("Applying changes for", len(res_files), "resolved objects") + + id_offset = int(node_labels.max()) + 1 + for resf in res_files: + seg_id = int(os.path.splitext(resf)[0]) + resf = os.path.join(result_folder, resf) + res = np.load(resf) + + this_ids, this_labels = res['node_ids'], res['node_labels'] + assert len(this_ids) == len(this_labels) + this_ids_exp = np.where(node_labels == seg_id)[0] + assert np.array_equal(np.sort(this_ids), this_ids_exp) + + this_labels += id_offset + node_labels[this_ids] = this_labels + id_offset = int(this_labels.max()) + 1 + + 
with open_file(path) as f: + chunks = (min(int(1e6), len(node_labels)),) + ds = f.require_dataset(out_key, compression='gzip', dtype=node_labels.dtype, + chunks=chunks, shape=node_labels.shape) + ds[:] = node_labels + + +def to_paintera_format(in_path, in_key, out_path, out_key): + with open_file(in_path, 'r') as f: + node_labels = f[in_key][:] + node_labels = vigra.analysis.relabelConsecutive(node_labels, start_label=1, keep_zeros=True)[0] + + n_ws = len(node_labels) + ws_ids = np.arange(n_ws, dtype='uint64') + assert len(ws_ids) == len(node_labels) + + seg_ids, seg_counts = np.unique(node_labels, return_counts=True) + trivial_segments = seg_ids[seg_counts == 1] + trivial_mask = np.in1d(node_labels, trivial_segments) + + ws_ids = ws_ids[~trivial_mask] + node_labels = node_labels[~trivial_mask] + + node_labels[node_labels != 0] += n_ws + + max_id = node_labels.max() + print("new max id:", max_id) + + paintera_labels = np.concatenate([ws_ids[:, None], node_labels[:, None]], axis=1).T + print(paintera_labels.shape) + + with open_file(out_path) as f: + chunks = (1, min(paintera_labels.shape[1], int(1e6))) + f.create_dataset(out_key, data=paintera_labels, chunks=chunks, + compression='gzip') + + +def zero_out_ids(node_label_in_path, node_label_in_key, + node_label_out_path, node_label_out_key, + zero_ids): + with open_file(node_label_in_path, 'r') as f: + ds = f[node_label_in_key] + node_labels = ds[:] + chunks = ds.chunks + + zero_mask = np.isin(node_labels, zero_ids) + node_labels[zero_mask] = 0 + + with open_file(node_label_out_path) as f: + ds = f.require_dataset(node_label_out_key, shape=node_labels.shape, chunks=chunks, + compression='gzip', dtype=node_labels.dtype) + ds[:] = node_labels + + +# TODO refactor this properly +def get_bounding_boxes(table_path, table_key, scale_factor): + with open_file(table_path, 'r') as f: + table = f[table_key][:] + bb_starts = table[:, 5:8] + bb_stops = table[:, 8:] + bb_starts /= scale_factor + bb_stops /= scale_factor + bounding_boxes = [tuple(slice(int(floor(sta)), + int(ceil(sto))) for sta, sto in zip(start, stop)) + for start, stop in zip(bb_starts, bb_stops)] + return bounding_boxes + + +def check_exported(res_file, raw_path, raw_key, ws_path, ws_key, + table_path, table_key, scale_factor): + from heimdall import view + import nifty.tools as nt + seg_id = int(os.path.splitext(os.path.split(res_file)[1])[0]) + res = np.load(res_file) + + bb = get_bounding_boxes(table_path, table_key, scale_factor)[seg_id] + + node_ids, node_labels = res['node_ids'], res['node_labels'] + assert len(node_ids) == len(node_labels) + + with open_file(raw_path, 'r') as f: + ds = f[raw_key] + ds.n_threads = 8 + raw = ds[bb] + with open_file(ws_path, 'r') as f: + ds = f[ws_key] + ds.n_threads = 8 + ws = ds[bb] + + seg_mask = np.isin(ws, node_ids) + ws[~seg_mask] = 0 + label_dict = {wsid: lid for wsid, lid in zip(node_ids, node_labels)} + label_dict[0] = 0 + seg = nt.takeDict(label_dict, ws) + + view(raw, seg) + + +def check_exported_paintera(paintera_path, assignment_key, + node_label_path, node_label_key, + table_path, table_key, scale_factor, + raw_path, raw_key, seg_path, seg_key, + check_ids): + from heimdall import view + import nifty.tools as nt + + with open_file(paintera_path, 'r') as f: + ds = f[assignment_key] + new_assignments = ds[:].T + + with open_file(node_label_path, 'r') as f: + ds = f[node_label_key] + node_labels = ds[:] + + bounding_boxes = get_bounding_boxes(table_path, table_key, scale_factor) + with open_file(seg_path, 'r') as fseg, 
open_file(raw_path, 'r') as fraw:
+        ds_seg = fseg[seg_key]
+        ds_seg.n_threads = 8
+        ds_raw = fraw[raw_key]
+        ds_raw.n_threads = 8
+
+        for seg_id in check_ids:
+            bb = bounding_boxes[seg_id]
+            raw = ds_raw[bb]
+            ws = ds_seg[bb]
+
+            ws_ids = np.where(node_labels == seg_id)[0]
+            seg_mask = np.isin(ws, ws_ids)
+            ws[~seg_mask] = 0
+
+            new_label_mask = np.isin(new_assignments[:, 0], ws_ids)
+            new_label_dict = dict(zip(new_assignments[:, 0][new_label_mask],
+                                      new_assignments[:, 1][new_label_mask]))
+            new_label_dict[0] = 0
+
+            # I am not sure why this happens
+            un_ws = np.unique(ws)
+            un_labels = list(new_label_dict.keys())
+            missing = np.setdiff1d(un_ws, un_labels)
+            print("Number of missing:", len(missing))
+            new_label_dict.update({miss: 0 for miss in missing})
+
+            seg_new = nt.takeDict(new_label_dict, ws)
+            view(raw, seg_mask.astype('uint32'), seg_new)
diff --git a/scripts/segmentation/correction/heuristics.py b/scripts/segmentation/correction/heuristics.py
new file mode 100644
index 0000000..8a16ecf
--- /dev/null
+++ b/scripts/segmentation/correction/heuristics.py
@@ -0,0 +1,173 @@
+import json
+from concurrent import futures
+from math import ceil, floor
+
+import numpy as np
+import nifty.distributed as ndist
+import vigra
+
+from scipy.ndimage.morphology import binary_closing, binary_opening, binary_erosion
+from skimage.morphology import convex_hull_image
+from elf.io import open_file
+
+
+#
+# heuristics to find morphology outliers
+#
+
+
+def closing_ratio(seg, n_iters):
+    seg_closed = binary_closing(seg, iterations=n_iters)
+    m1 = float(seg.sum())
+    m2 = float(seg_closed.sum())
+    return m2 / m1
+
+
+def opening_ratio(seg, n_iters):
+    seg_opened = binary_opening(seg, iterations=n_iters)
+    m1 = float(seg.sum())
+    m2 = float(seg_opened.sum())
+    return m1 / m2
+
+
+def convex_hull_ratio(seg):
+    seg_conv = convex_hull_image(seg)
+    m1 = float(seg.sum())
+    m2 = float(seg_conv.sum())
+    return m2 / m1
+
+
+def components_per_slice(seg, n_iters=0):
+    n_components = 0
+    n_slices = seg.shape[0]
+    if n_slices == 0:
+        return 0. 
+ + for z in range(n_slices): + if n_iters > 0: + segz = binary_erosion(seg[z], iterations=n_iters).astype('uint32') + else: + segz = seg[z].astype('uint32') + n_components += len(np.unique(vigra.analysis.labelImageWithBackground(segz))[1:]) + n_components /= n_slices + return n_components + + +def read_bb_from_table(table_path, table_key, scale_factor): + with open_file(table_path, 'r') as f: + table = f[table_key][:] + bb_starts = table[:, 5:8] / scale_factor + bb_stops = table[:, 8:] / scale_factor + bbs = [tuple(slice(int(floor(sta)), + int(ceil(sto))) for sta, sto in zip(starts, stops)) + for starts, stops in zip(bb_starts, bb_stops)] + return bbs + + +def compute_ratios(seg_path, seg_key, table_path, table_key, + scale_factor, n_threads, + compute_ratio, seg_ids=None, sort=False): + with open_file(seg_path, 'r') as f: + ds = f[seg_key] + bounding_boxes = read_bb_from_table(table_path, table_key, scale_factor) + + if seg_ids is None: + seg_ids = np.arange(len(bounding_boxes)) + n_seg_ids = len(seg_ids) + + def _compute_ratio(seg_id): + print("%i / %i" % (seg_id, n_seg_ids)) + bb = bounding_boxes[seg_id] + seg = ds[bb] + seg = seg == seg_id + ratio = compute_ratio(seg) + return ratio + + with futures.ThreadPoolExecutor(n_threads) as tp: + tasks = [tp.submit(_compute_ratio, seg_id) for seg_id in seg_ids] + ratios = np.array([t.result() for t in tasks]) + + if sort: + sorted_ids = np.argsort(ratios)[::-1] + seg_ids = seg_ids[sorted_ids] + ratios = ratios[sorted_ids] + + return seg_ids, ratios + + +# +# heuristics to find likely false merges +# + +def get_ignore_ids(label_path, label_key, + ignore_names=['yolk', 'cuticle', 'neuropil']): + with open_file(label_path, 'r') as f: + ds = f[label_key] + semantics = ds.attrs['semantics'] + labels = ds[:] + ignore_label_ids = [] + for name, ids in semantics.items(): + if name in ignore_names: + ignore_label_ids.extend(ids) + ignore_ids = np.isin(labels, ignore_label_ids) + ignore_ids = np.where(ignore_ids)[0] + return ignore_ids + + +def weight_quantile_heuristic(seg_id, graph, node_labels, sizes, max_size, weights, + quantile=90): + size_ratio = float(sizes[seg_id]) / max_size + node_ids = np.where(node_labels == seg_id)[0] + edges = graph.extractSubgraphFromNodes(node_ids, allowInvalidNodes=True)[0] + this_weights = weights[edges] + try: + score = np.percentile(this_weights, quantile) * size_ratio + except IndexError: + print("Something went wrong", seg_id, this_weights.shape) + score = 0. 
+ return score + + +def rank_false_merges(problem_path, graph_key, feat_key, + morpho_key, node_label_path, node_label_key, + ignore_ids, out_path_ids, out_path_scores, + n_threads, n_candidates, heuristic=weight_quantile_heuristic): + g = ndist.Graph(problem_path, graph_key, n_threads) + with open_file(problem_path, 'r') as f: + ds = f[feat_key] + ds.n_threads = n_threads + probs = ds[:, 0] + + ds = f[morpho_key] + ds.n_threads = n_threads + sizes = ds[:, 1] + + with open_file(node_label_path, 'r') as f: + ds = f[node_label_key] + ds.n_threads = n_threads + node_labels = ds[:] + + seg_ids = np.arange(len(sizes), dtype='uint64') + seg_ids = seg_ids[np.argsort(sizes)[::-1]][:n_candidates] + seg_ids = seg_ids[~np.isin(seg_ids, ignore_ids.tolist() + [0])] + max_size = sizes[seg_ids].max() + with futures.ThreadPoolExecutor(n_threads) as tp: + tasks = [tp.submit(weight_quantile_heuristic, seg_id, g, + node_labels, sizes, max_size, probs) for seg_id in seg_ids] + fm_scores = np.array([t.result() for t in tasks]) + + # print("Id:", seg_ids[0]) + # sc = weight_quantile_heuristic(seg_ids[0], g, + # node_labels, sizes, max_size, probs) + # print("Score:", sc) + # return + + # sort ids by score (decreasing) + sorter = np.argsort(fm_scores)[::-1] + seg_ids = seg_ids[sorter] + fm_scores = fm_scores[sorter] + + with open(out_path_scores, 'w') as f: + json.dump(fm_scores.tolist(), f) + with open(out_path_ids, 'w') as f: + json.dump(seg_ids.tolist(), f) diff --git a/scripts/segmentation/correction/mark_false_merges.py b/scripts/segmentation/correction/mark_false_merges.py deleted file mode 100644 index e05bfc7..0000000 --- a/scripts/segmentation/correction/mark_false_merges.py +++ /dev/null @@ -1,20 +0,0 @@ -import os -import json -import napari -from elf.io import open_file - - -# TODO -# mark false merges in segmentation: -# open napari viewer with raw data and segmentation (in appropriate res) -# set annotations in 'has_false_merge' and 'is_correct' layer -# clear segmentations that have an annotation (upon key press) -# store annotations in the project folder -class MarkFalseMerges: - def __init__(self, project_folder, - raw_path=None, raw_key=None, - seg_path=None, seg_key=None): - pass - - def __call__(self): - pass diff --git a/scripts/segmentation/correction/preprocess.py b/scripts/segmentation/correction/preprocess.py new file mode 100644 index 0000000..0522800 --- /dev/null +++ b/scripts/segmentation/correction/preprocess.py @@ -0,0 +1,117 @@ +import os +import json +import luigi +from elf.io import open_file + +from paintera_tools import serialize_from_commit +from paintera_tools.util import compute_graph_and_weights +from cluster_tools.node_labels import NodeLabelWorkflow +from cluster_tools.morphology import MorphologyWorkflow +from cluster_tools.downscaling import DownscalingWorkflow + + +def graph_and_costs(path, ws_key, aff_key, out_path): + tmp_folder = './tmp_preprocess' + compute_graph_and_weights(path, aff_key, + path, ws_key, out_path, + tmp_folder, target='slurm', max_jobs=250, + offsets=[[-1, 0, 0], [0, -1, 0], [0, 0, -1]], + with_costs=True) + + +def accumulate_node_labels(ws_path, ws_key, seg_path, seg_key, + out_path, out_key, prefix): + task = NodeLabelWorkflow + + tmp_folder = './tmp_preprocess' + config_dir = os.path.join(tmp_folder, 'configs') + + t = task(tmp_folder=tmp_folder, config_dir=config_dir, + target='slurm', max_jobs=250, + ws_path=ws_path, ws_key=ws_key, + input_path=seg_path, input_key=seg_key, + output_path=out_path, output_key=out_key, + prefix=prefix) + ret = 
luigi.build([t], local_scheduler=True) + assert ret + + +def compute_bounding_boxes(path, key): + task = MorphologyWorkflow + tmp_folder = './tmp_preprocess' + config_dir = os.path.join(tmp_folder, 'configs') + + out_key = 'morphology' + t = task(tmp_folder=tmp_folder, config_dir=config_dir, + target='slurm', max_jobs=250, + input_path=path, input_key=key, + output_path=path, output_key=out_key) + ret = luigi.build([t], local_scheduler=True) + assert ret + + +def downscale_segmentation(path, key): + task = DownscalingWorkflow + + tmp_folder = './tmp_preprocess' + config_dir = os.path.join(tmp_folder, 'configs') + + configs = task.get_config() + conf = configs['downscaling'] + conf.update({'library_kwargs': {'order': 0}}) + with open(os.path.join(config_dir, 'downscaling.config'), 'w') as f: + json.dump(conf, f) + + in_key = os.path.join(key, 's0') + n_scales = 5 + scales = n_scales * [[2, 2, 2]] + halos = n_scales * [[0, 0, 0]] + + t = task(tmp_folder=tmp_folder, config_dir=config_dir, + # target='slurm', max_jobs=250, + target='local', max_jobs=64, + input_path=path, input_key=in_key, + scale_factors=scales, halos=halos, + output_path=path, output_key_prefix=key) + ret = luigi.build([t], local_scheduler=True) + assert ret + + +def copy_tissue_labels(in_path, out_path, out_key): + with open_file(in_path, 'r') as f: + names = f['semantic_names'][:] + mapping = f['semantic_mapping'][:] + + semantics = {name: ids.tolist() for name, ids in zip(names, mapping)} + with open_file(out_path) as f: + ds = f[out_key] + ds.attrs['semantics'] = semantics + + +# preprocess: +# - export current paintera segmentation +# - build graph and compute weights for current superpixels +# - get current node labeling +# - compute bounding boxes for current segments +# - downscale the segmentation +def preprocess(path, key, aff_key, + tissue_path, tissue_key, + out_path, out_key): + tmp_folder = './tmp_preprocess' + out_key0 = os.path.join(out_key, 's0') + + serialize_from_commit(path, key, out_path, out_key0, tmp_folder, + max_jobs=250, target='slurm', relabel_output=True) + + ws_key = os.path.join(key, 'data', 's0') + graph_and_costs(path, ws_key, aff_key, out_path) + + accumulate_node_labels(path, ws_key, out_path, out_key0, + out_path, 'node_labels', prefix='node_labels') + + accumulate_node_labels(out_path, out_key0, tissue_path, tissue_key, + out_path, 'tissue_labels', prefix='tissue') + copy_tissue_labels(tissue_path, out_path, 'tissue_labels') + + compute_bounding_boxes(out_path, out_key0) + downscale_segmentation(out_path, out_key) diff --git a/scripts/segmentation/validation/eval_nuclei.py b/scripts/segmentation/validation/eval_nuclei.py index 8b56dc4..86b985f 100644 --- a/scripts/segmentation/validation/eval_nuclei.py +++ b/scripts/segmentation/validation/eval_nuclei.py @@ -1,6 +1,72 @@ -from .evaluate_annotations import evaluate_annotations +from math import ceil, floor +import vigra +from elf.io import open_file, is_dataset +from .evaluate_annotations import evaluate_annotations, merge_evaluations -# TODO -def eval_nuclei(): - pass +def get_bounding_box(ds, scale_factor): + attrs = ds.attrs + start, stop = attrs['starts'], attrs['stops'] + bb = tuple(slice(int(floor(sta / scale_factor)), + int(ceil(sto / scale_factor))) for sta, sto in zip(start, stop)) + return bb + + +def to_scores(eval_res): + n = float(eval_res['n_annotations'] - eval_res['n_unmatched']) + fp = eval_res['n_splits'] + fn = eval_res['n_merged_annotations'] + return fp / n, fn / n, n + + +# we may want to do a max projection of some z 
context ?! +def get_nucleus_segmentation(ds_seg, bb): + seg = ds_seg[bb].squeeze().astype('uint32') + return seg + + +# need to downsample the annotations and bounding box to fit the +# nucleus segmentation +def eval_slice(ds_seg, ds_ann, min_radius, return_masks=False): + ds_seg.n_threads = 8 + ds_ann.n_threads = 8 + + bb = get_bounding_box(ds_ann, scale_factor=4.) + annotations = ds_ann[:] + seg = get_nucleus_segmentation(ds_seg, bb) + annotations = vigra.sampling.resize(annotations.astype('float32'), + shape=seg.shape, order=0).astype('uint32') + + fg_annotations = (annotations == 1).astype('uint32') + bg_annotations = None + + return evaluate_annotations(seg, fg_annotations, bg_annotations, + min_radius=min_radius, return_masks=return_masks) + + +def eval_nuclei(seg_path, seg_key, + annotation_path, annotation_key=None, + min_radius=6): + """ Evaluate the nucleus segmentation by computing + the percentage of false positive and false negative nucleus annotations + in manually annotated validation slices. + """ + eval_res = {} + with open_file(seg_path, 'r') as f_seg, open_file(annotation_path, 'r') as f_ann: + ds_seg = f_seg[seg_key] + g = f_ann if annotation_key is None else f_ann[annotation_key] + + def visit_annotation(name, node): + nonlocal eval_res + if is_dataset(node): + print("Evaluating:", name) + res = eval_slice(ds_seg, node, min_radius) + eval_res = merge_evaluations(res, eval_res) + # for debugging + # print("current eval:", eval_res) + else: + print("Group:", name) + + g.visititems(visit_annotation) + + return to_scores(eval_res) diff --git a/scripts/segmentation/validation/evaluate_annotations.py b/scripts/segmentation/validation/evaluate_annotations.py index 4fda835..ab4ffd0 100644 --- a/scripts/segmentation/validation/evaluate_annotations.py +++ b/scripts/segmentation/validation/evaluate_annotations.py @@ -22,7 +22,7 @@ def get_radii(seg): return radii -def evaluate_annotations(seg, fg_annotations, bg_annotations, +def evaluate_annotations(seg, fg_annotations, bg_annotations=None, ignore_seg_ids=None, min_radius=16, return_masks=False, return_ids=False): """ Evaluate segmentation based on evaluations. 
@@ -54,7 +54,7 @@ def evaluate_annotations(seg, fg_annotations, bg_annotations, # check if this is an ignore id and skip if ignore_seg_ids is not None and seg_id in ignore_seg_ids: continue - has_bg_label = bg_annotations[mask].sum() > 0 + has_bg_label = False if bg_annotations is None else bg_annotations[mask].sum() > 0 # find the overlapping label ids this_labels = np.unique(labels[mask]) diff --git a/scripts/segmentation/validation/refine_annotations.py b/scripts/segmentation/validation/refine_annotations.py index f58df90..c8fbf2f 100644 --- a/scripts/segmentation/validation/refine_annotations.py +++ b/scripts/segmentation/validation/refine_annotations.py @@ -33,8 +33,8 @@ def compute_masks(seg, labels, ignore_seg_ids): def refine(seg_path, seg_key, ignore_seg_ids, orientation, slice_id, project_folder, - annotation_path='/g/arendt/...', - raw_path='/g/arendt/...', + annotation_path='/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation/validation_annotations.h5', + raw_path='/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/sbem-6dpf-1-whole-raw.h5', raw_key='t00000/s00/1/cells'): label_path = os.path.join(project_folder, 'labels.npy') @@ -47,6 +47,8 @@ def refine(seg_path, seg_key, ignore_seg_ids, labels = np.load(label_path) if os.path.exists(label_path) else None fm = np.load(fm_path) if os.path.exists(fm_path) else None fs = np.load(fs_path) if os.path.exists(fs_path) else None + else: + labels, fm, fs = None, None, None with open_file(annotation_path, 'r') as fval: ds = fval[orientation][str(slice_id)] @@ -55,8 +57,8 @@ def refine(seg_path, seg_key, ignore_seg_ids, if labels is None: labels = ds[:] - starts = [b.start for b in bb] - stops = [b.stop for b in bb] + starts = [int(b.start) for b in bb] + stops = [int(b.stop) for b in bb] with open_file(seg_path, 'r') as f: ds = f[seg_key] @@ -68,7 +70,7 @@ def refine(seg_path, seg_key, ignore_seg_ids, ds.n_threads = 8 raw = ds[bb].squeeze() - assert labels.shape == seg.shape + assert labels.shape == seg.shape == raw.shape if fm is None: assert fs is None fm, fs = compute_masks(seg, labels, ignore_seg_ids) diff --git a/segmentation/correction/correct_cell_segmentation.py b/segmentation/correction/correct_cell_segmentation.py new file mode 100644 index 0000000..d3b88ac --- /dev/null +++ b/segmentation/correction/correct_cell_segmentation.py @@ -0,0 +1,175 @@ +import os +import json +from scripts.segmentation.correction import (preprocess, + AnnotationTool, + CorrectionTool, + export_node_labels, + to_paintera_format, + rank_false_merges, + get_ignore_ids) + + +def run_preprocessing(project_folder): + os.makedirs(project_folder, exist_ok=True) + + path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5' + key = 'volumes/paintera/proofread_cells' + aff_key = 'volumes/curated_affinities/s1' + + out_path = os.path.join(project_folder, 'data.n5') + out_key = 'segmentation' + + tissue_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data', + 'rawdata/sbem-6dpf-1-whole-segmented-tissue-labels.h5') + tissue_key = 't00000/s00/0/cells' + + preprocess(path, key, aff_key, tissue_path, tissue_key, + out_path, out_key) + + +def run_heuristics(project_folder): + path = os.path.join(project_folder, 'data.n5') + ignore_ids = get_ignore_ids(path, 'tissue_labels') + + out_path_ids = os.path.join(project_folder, 'fm_candidate_ids.json') + out_path_scores = os.path.join(project_folder, 'fm_candidate_scores.json') + n_threads = 32 + n_candidates = 10000 + + rank_false_merges(path, 's0/graph', 'features', + 
'morphology', path, 'node_labels', + ignore_ids, out_path_ids, out_path_scores, + n_threads, n_candidates) + + +def run_annotations(id_path, project_folder, scale=2, + with_node_labels=True): + p1 = os.path.join(project_folder, 'data.n5') + table_key = 'morphology' + scale_factor = 2 ** scale + + p2 = '/g/kreshuk/data/arendt/platyneris_v1/data.n5' + rk = 'volumes/raw/s%i' % (scale + 1,) + + if with_node_labels: + node_label_path = p1 + node_label_key = 'node_labels' + wsp = p2 + wsk = 'volumes/paintera/proofread_cells/data/s%i' % scale + else: + node_label_path, node_label_key = None, None + wsp = p1 + wsk = 'segmentation/s%i' % scale + + annotator = AnnotationTool(project_folder, id_path, + p1, table_key, scale_factor, + p2, rk, wsp, wsk, + node_label_path, node_label_key) + annotator() + + +def filter_fm_ids_by_annotations(project_folder, annotation_path): + annotation = 'revisit' + with open(annotation_path) as f: + annotations = json.load(f) + annotations = {int(k): v for k, v in annotations.items()} + ids = [k for k, v in annotations.items() if v == annotation] + print("Found", len(ids), "ids with annotation", annotation) + out_path = os.path.join(project_folder, 'fm_ids_filtered.json') + with open(out_path, 'w') as f: + json.dump(ids, f) + return out_path + + +def run_correction(project_folder, fm_id_path, scale=2): + p1 = os.path.join(project_folder, 'data.n5') + table_key = 'morphology' + segk = 'segmentation/s%i' % scale + scale_factor = 2 ** scale + + gk = 's0/graph' + fk = 'features' + + p2 = '/g/kreshuk/data/arendt/platyneris_v1/data.n5' + rk = 'volumes/raw/s%i' % (scale + 1,) + wsk = 'volumes/paintera/proofread_cells/data/s%i' % scale + + # table path + correcter = CorrectionTool(project_folder, fm_id_path, + p1, table_key, scale_factor, + p2, rk, p1, segk, p2, wsk, + p1, 'node_labels', + p1, gk, fk) + correcter() + + +def export_correction(project_folder, correct_merges=True, zero_out=True): + from scripts.segmentation.correction.export_node_labels import zero_out_ids + + p = os.path.join(project_folder, 'data.n5') + in_key = 'node_labels' + out_key = 'node_labels_corrected' + + if correct_merges: + print("Correcting merged ids") + project_folder = './proj_correct2' + export_node_labels(p, in_key, out_key, project_folder) + next_in_key = out_key + else: + next_in_key = in_key + + def get_zero_ids(annotation_path): + annotation = 'merge' + with open(annotation_path) as f: + annotations = json.load(f) + annotations = {int(k): v for k, v in annotations.items()} + zero_ids = [k for k, v in annotations.items() if v == annotation] + return zero_ids + + if zero_out: + # read merge annotations from the morphology annotator + annotation_path = './proj_annotate_morphology/annotations.json' + zero_ids = get_zero_ids(annotation_path) + # read additional merge annotations + annotation_path = './proj_correct2/annotations.json' + zero_ids += get_zero_ids(annotation_path) + print("Zeroing out", len(zero_ids), "ids") + + zero_out_ids(p, next_in_key, p, out_key, zero_ids) + + paintera_path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5' + paintera_key = 'volumes/paintera/proofread_cells_multiset/fragment-segment-assignment' + print("Exporting to paintera format") + to_paintera_format(p, out_key, paintera_path, paintera_key) + + +def correction_workflow_from_ranked_false_merges(): + # the project folder where we store intermediate results + # for the false merge correction + project_folder = './project_correct_false_merges' + + # compute the segmentation, the region adjacency graph and the 
graph weights
+    run_preprocessing(project_folder)
+
+    # compute the false merge heuristics based on a morphology criterion
+    # (default: number of connected components per slice)
+    run_heuristics(project_folder)
+    fm_id_path = os.path.join(project_folder, 'fm_candidate_ids.json')
+
+    # run annotations to quickly filter for the false merges that need to be
+    # corrected. this is not strictly necessary (we can use run_correction directly)
+    # but in my experience, it is faster to first just filter for false merges
+    # and then to correct them.
+    run_annotations(fm_id_path, project_folder)
+    annotation_path = os.path.join(project_folder, 'annotations.json')
+    fm_id_path = filter_fm_ids_by_annotations(project_folder, annotation_path)
+
+    # run the false merge correction tool
+    run_correction(project_folder, fm_id_path)
+
+    # export the corrected node labels to paintera
+    export_correction(project_folder)
+
+
+if __name__ == '__main__':
+    correction_workflow_from_ranked_false_merges()
diff --git a/segmentation/correction/deprecated.py b/segmentation/correction/deprecated.py
new file mode 100644
index 0000000..a03ded7
--- /dev/null
+++ b/segmentation/correction/deprecated.py
@@ -0,0 +1,47 @@
+#
+# just keeping this here for reference
+#
+
+import os
+import json
+
+
+def get_skipped_ids(out_path):
+    processed_ids = './proj_correct_fms/processed_ids.json'
+    res_dir = './proj_correct_fms/results'
+    with open(processed_ids) as f:
+        processed_ids = json.load(f)
+
+    saved_ids = os.listdir(res_dir)
+    saved_ids = [int(os.path.splitext(sid)[0]) for sid in saved_ids]
+
+    skipped_ids = list(set(processed_ids) - set(saved_ids))
+    print("N-skipped:")
+    print(len(skipped_ids))
+
+    with open(out_path, 'w') as f:
+        json.dump(skipped_ids, f)
+
+
+def check_export():
+    from scripts.segmentation.correction.export_node_labels import check_exported_paintera
+
+    path1 = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    assignment_key = 'volumes/paintera/proofread_cells/fragment-segment-assignment'
+    raw_key = 'volumes/raw/s3'
+    seg_key = 'volumes/paintera/proofread_cells/data/s2'
+
+    path2 = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/analysis/correction/data.n5'
+    node_label_key = 'node_labels'
+    table_key = 'morphology'
+    scale_factor = 4
+
+    res_folder = './proj_correct_fms/results'
+    ids = os.listdir(res_folder)[:10]
+    ids = [int(os.path.splitext(idd)[0]) for idd in ids]
+
+    check_exported_paintera(path1, assignment_key,
+                            path2, node_label_key,
+                            path2, table_key, scale_factor,
+                            path1, raw_key, path1, seg_key,
+                            ids)
diff --git a/segmentation/correction/fix_assignments.py b/segmentation/correction/fix_assignments.py
new file mode 100644
index 0000000..e8ad6ee
--- /dev/null
+++ b/segmentation/correction/fix_assignments.py
@@ -0,0 +1,69 @@
+import os
+import json
+
+
+def get_assignments(version):
+    from scripts.segmentation.correction.assignment_diffs import node_labels
+    assert version in ('0.5.5', '0.6.1', 'local')
+    ws_path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    ws_key = 'volumes/paintera/proofread_cells_multiset/data/s0'
+
+    if version == 'local':
+        in_path = './data.n5'
+        in_key = 'segmentation/s0'
+        prefix = version
+    else:
+        in_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data',
+                               '%s/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5' % version)
+        in_key = 't00000/s00/0/cells'
+
+        prefix = ''.join(version.split('.'))
+
+    out_path = './data.n5'
+    out_key = 'assignments/%s' % prefix
+
+    tmp_folder = './tmp_%s' % prefix
+
+    node_labels(ws_path, ws_key, in_path, in_key,
+                out_path, out_key, prefix,
tmp_folder) + + +def get_split_assignments(): + import z5py + from scripts.segmentation.correction.assignment_diffs import assignment_diff_splits + with z5py.File('./data.n5') as f: + ref = f['assignments/055'][:] + new = f['assignments/corrected'][:] + + splits = assignment_diff_splits(ref, new) + + print(len(splits)) + + # with open('./split_assignments_local.json', 'w') as f: + # json.dump(splits, f) + + +def blub(): + with open('./split_assignments.json') as f: + split_assignments = json.load(f) + ids = list(split_assignments.keys()) + # vals = list(split_assignments.values()) + + correction_path = './proj_correct2/results' + correct_ids = os.listdir(correction_path) + correct_ids = [int(os.path.splitext(cid)[0]) for cid in correct_ids] + + print("Number of assignments which are split:", len(ids)) + print("Number of assignments which were supposed to be split:", len(correct_ids)) + + diff = set(ids) - set(correct_ids) + print("Number after diff:", len(diff)) + + +if __name__ == '__main__': + # get_assignments('0.5.5') + # get_assignments('0.6.1') + # get_assignments('local') + + get_split_assignments() + # blub() diff --git a/segmentation/correction/morphology_outlier.py b/segmentation/correction/morphology_outlier.py new file mode 100644 index 0000000..5a2cdfd --- /dev/null +++ b/segmentation/correction/morphology_outlier.py @@ -0,0 +1,53 @@ +import os +import json +import numpy as np +import z5py +from scripts.segmentation.correction import preprocess +from scripts.segmentation.correction.heuristics import (compute_ratios, + components_per_slice, + get_ignore_ids) + + +def run_preprocessing(): + path = '/g/kreshuk/data/arendt/platyneris_v1/data.n5' + key = 'volumes/paintera/proofread_cells' + aff_key = 'volumes/curated_affinities/s1' + + out_path = './data.n5' + out_key = 'segmentation' + + tissue_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data', + 'rawdata/sbem-6dpf-1-whole-segmented-tissue-labels.h5') + tissue_key = 't00000/s00/0/cells' + + preprocess(path, key, aff_key, tissue_path, tissue_key, + out_path, out_key) + + +def morphology_outlier(): + scale = 2 + p1 = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/analysis/correction/data.n5' + seg_key = 'segmentation/s%i' % scale + table_key = 'morphology' + scale_factor = 2 ** scale + + with z5py.File(p1, 'r') as f: + ds = f['segmentation/s0'] + n_ids = ds.attrs['maxId'] + 1 + + seg_ids = np.arange(n_ids) + ignore_ids = get_ignore_ids(p1, 'tissue_labels') + ignore_id_mask = ~np.isin(seg_ids, ignore_ids) + seg_ids = seg_ids[ignore_id_mask] + + seg_ids, ratios_close = compute_ratios(p1, seg_key, p1, table_key, scale_factor, + 64, components_per_slice, seg_ids, True) + with open('./ratios_components.json', 'w') as f: + json.dump(ratios_close.tolist(), f) + with open('./ids_morphology.json', 'w') as f: + json.dump(seg_ids.tolist(), f) + + +if __name__ == '__main__': + # run_preprocessing() + morphology_outlier() diff --git a/segmentation/validation/check_validation.py b/segmentation/validation/check_validation.py new file mode 100644 index 0000000..d85f41f --- /dev/null +++ b/segmentation/validation/check_validation.py @@ -0,0 +1,106 @@ +import os +import h5py +from heimdall import view, to_source +from scripts.segmentation.validation import eval_cells, eval_nuclei, get_ignore_seg_ids + +# ROOT_FOLDER = '../../data' +ROOT_FOLDER = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data' + + +def check_cell_evaluation(): + from scripts.segmentation.validation.eval_cells import (eval_slice, + 
get_bounding_box)
+
+    praw = os.path.join(ROOT_FOLDER, 'rawdata/sbem-6dpf-1-whole-raw.h5')
+    pseg = os.path.join(ROOT_FOLDER, '0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    table_path = os.path.join(ROOT_FOLDER, '0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv')
+    ignore_seg_ids = get_ignore_seg_ids(table_path)
+
+    with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann:
+        ds_seg = fseg['t00000/s00/0/cells']
+        ds_ann = fann['xy/1000']
+
+        print("Run evaluation ...")
+        res, masks = eval_slice(ds_seg, ds_ann, ignore_seg_ids, min_radius=16,
+                                return_masks=True)
+        fm, fs = masks['merges'], masks['splits']
+        print()
+        print("Eval result")
+        print(res)
+        print()
+
+        print("Load raw data ...")
+        bb = get_bounding_box(ds_ann)
+        with h5py.File(praw, 'r') as f:
+            raw = f['t00000/s00/1/cells'][bb].squeeze()
+
+        print("Load seg data ...")
+        seg = ds_seg[bb].squeeze().astype('uint32')
+
+        view(to_source(raw, name='raw'), to_source(seg, name='seg'),
+             to_source(fm, name='merges'), to_source(fs, name='splits'))
+
+
+def eval_all_cells():
+    pseg = os.path.join(ROOT_FOLDER, '0.5.5/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    table_path = os.path.join(ROOT_FOLDER, '0.5.5/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv')
+    ignore_seg_ids = get_ignore_seg_ids(table_path)
+
+    res = eval_cells(pseg, 't00000/s00/0/cells', pann,
+                     ignore_seg_ids=ignore_seg_ids)
+    print("Eval result:")
+    print(res)
+
+
+def check_nucleus_evaluation():
+    from scripts.segmentation.validation.eval_nuclei import (eval_slice,
+                                                             get_bounding_box)
+
+    praw = os.path.join(ROOT_FOLDER, 'rawdata/sbem-6dpf-1-whole-raw.h5')
+    pseg = os.path.join(ROOT_FOLDER, '0.0.0/segmentations/sbem-6dpf-1-whole-segmented-nuclei-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    with h5py.File(pseg, 'r') as fseg, h5py.File(pann, 'r') as fann:
+        ds_seg = fseg['t00000/s00/0/cells']
+        ds_ann = fann['xy/1000']
+
+        print("Run evaluation ...")
+        res, masks = eval_slice(ds_seg, ds_ann, min_radius=6,
+                                return_masks=True)
+        fm, fs = masks['merges'], masks['splits']
+        print()
+        print("Eval result")
+        print(res)
+        print()
+
+        print("Load raw data ...")
+        bb = get_bounding_box(ds_ann, scale_factor=4.)
+        with h5py.File(praw, 'r') as f:
+            raw = f['t00000/s00/3/cells'][bb].squeeze()
+
+        print("Load seg data ...")
+        seg = ds_seg[bb].squeeze().astype('uint32')
+
+        view(to_source(raw, name='raw'), to_source(seg, name='seg'),
+             to_source(fm, name='merges'), to_source(fs, name='splits'))
+
+
+def eval_all_nuclei():
+    pseg = os.path.join(ROOT_FOLDER, '0.0.0/segmentations/sbem-6dpf-1-whole-segmented-nuclei-labels.h5')
+    pann = os.path.join(ROOT_FOLDER, 'rawdata/evaluation/validation_annotations.h5')
+
+    res = eval_nuclei(pseg, 't00000/s00/0/cells', pann)
+    print("Eval result:")
+    print(res)
+
+
+if __name__ == '__main__':
+    # check_cell_evaluation()
+    eval_all_cells()
+
+    # check_nucleus_evaluation()
+    # eval_all_nuclei()
diff --git a/segmentation/validation/eval_cell_segmentation.py b/segmentation/validation/eval_cell_segmentation.py
new file mode 100644
index 0000000..20e1db5
--- /dev/null
+++ b/segmentation/validation/eval_cell_segmentation.py
@@ -0,0 +1,94 @@
+import argparse
+import os
+import numpy as np
+import h5py
+from scripts.segmentation.validation import eval_cells, get_ignore_seg_ids
+from scripts.attributes.region_attributes import region_attributes
+from scripts.default_config import write_default_global_config
+
+ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5'
+BASELINES = '../../data/rawdata/evaluation/baseline_cell_segmentations.h5'
+
+
+def get_label_ids(path, key):
+    with h5py.File(path, 'r') as f:
+        ds = f[key]
+        max_id = ds.attrs['maxId']
+    label_ids = np.arange(max_id + 1)
+    return label_ids
+
+
+def compute_baseline_tables():
+    names = ['lmc', 'mc', 'curated_lmc', 'curated_mc']
+    path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation',
+                        'baseline_cell_segmentations.h5')
+    table_prefix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation'
+    im_folder = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.0/images'
+    seg_folder = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.0/segmentations'
+    for name in names:
+        key = name
+        out_path = os.path.join(table_prefix, '%s.csv' % name)
+        tmp_folder = './tmp_regions_%s' % name
+        config_folder = os.path.join(tmp_folder, 'configs')
+        write_default_global_config(config_folder)
+        label_ids = get_label_ids(path, key)
+        region_attributes(path, out_path, im_folder, seg_folder,
+                          label_ids, tmp_folder, target='local', max_jobs=64,
+                          key_seg=key)
+
+
+def eval_seg(path, key, table):
+    ignore_ids = get_ignore_seg_ids(table)
+    fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
+                             ignore_seg_ids=ignore_ids)
+    print("Evaluation yields:")
+    print("False merges:", fm)
+    print("False splits:", fs)
+    print("Total number of annotations:", tot)
+
+
+def eval_baselines():
+    path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation',
+                        'baseline_cell_segmentations.h5')
+    names = ['lmc', 'mc', 'curated_lmc', 'curated_mc']
+    table_prefix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/rawdata/evaluation'
+    results = {}
+    for name in names:
+        print("Run evaluation for %s ..." % name)
+        table = os.path.join(table_prefix, '%s.csv' % name)
+        ignore_ids = get_ignore_seg_ids(table)
+        key = name
+        fm, fs, tot = eval_cells(path, key, ANNOTATIONS,
+                                 ignore_seg_ids=ignore_ids)
+        results[name] = (fm, fs, tot)
+
+    for name, (fm, fs, tot) in results.items():
+        print("Evaluation of", name, "yields:")
+        print("False merges:", fm)
+        print("False splits:", fs)
+        print("Total number of annotations:", tot)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", type=str, help="Path to segmentation that should be validated.")
+    parser.add_argument("table", type=str, help="Path to table with region/semantic assignments")
+    parser.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key")
+    parser.add_argument("--baselines", type=int, default=0,
+                        help="Whether to evaluate the baseline segmentations (overrides path)")
+    args = parser.parse_args()
+
+    baselines = bool(args.baselines)
+    if baselines:
+        eval_baselines()
+    else:
+        path = args.path
+        table = args.table
+        key = args.key
+        assert os.path.exists(path), path
+        eval_seg(path, key, table)
+
+
+if __name__ == '__main__':
+    # compute_baseline_tables()
+    main()
diff --git a/segmentation/validation/eval_nucleus_segmentation.py b/segmentation/validation/eval_nucleus_segmentation.py
new file mode 100644
index 0000000..98aa0e1
--- /dev/null
+++ b/segmentation/validation/eval_nucleus_segmentation.py
@@ -0,0 +1,29 @@
+import argparse
+import os
+from scripts.segmentation.validation import eval_nuclei
+
+ANNOTATIONS = '../../data/rawdata/evaluation/validation_annotations.h5'
+
+
+def eval_seg(path, key):
+    fp, fn, tot = eval_nuclei(path, key, ANNOTATIONS)
+    print("Evaluation yields:")
+    print("False positives:", fp)
+    print("False negatives:", fn)
+    print("Total number of annotations:", tot)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("path", type=str, help="Path to nuclei segmentation that should be validated.")
+    parser.add_argument("--key", type=str, default="t00000/s00/0/cells", help="Segmentation key")
+    args = parser.parse_args()
+
+    path = args.path
+    key = args.key
+    assert os.path.exists(path), path
+    eval_seg(path, key)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/segmentation/validation/proofread_annotations.py b/segmentation/validation/proofread_annotations.py
new file mode 100644
index 0000000..0fd4e1a
--- /dev/null
+++ b/segmentation/validation/proofread_annotations.py
@@ -0,0 +1,29 @@
+import os
+from scripts.segmentation.validation.refine_annotations import refine, export_refined
+from scripts.segmentation.validation.eval_cells import get_ignore_seg_ids
+
+
+def proofread(orientation, slice_id):
+    seg_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.1',
+                            'segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    seg_key = 't00000/s00/0/cells'
+    table_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.6.1',
+                              'tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv')
+    ignore_seg_ids = get_ignore_seg_ids(table_path)
+
+    proj_folder = 'project_folders/proofread_%s_%i' % (orientation, slice_id)
+    refine(seg_path, seg_key, ignore_seg_ids,
+           orientation, slice_id, proj_folder)
+
+
+# orientation xy has slices:
+# 1000, 2000, 4000, 7000
+# orientation xz has slices:
+# 4000, 5998, 8662, 9328
+if __name__ == '__main__':
+    orientation = 'xz'
+    slice_id = 4000
+    proofread(orientation, slice_id)
+    # done:
+    # xy: 1000, 2000, 4000, 7000
+    # xz: 4000 (partially)
--
GitLab
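
Usage note (illustrative, not part of the patch above): besides the command line entry points, the evaluation functions added in scripts/segmentation/validation can be driven directly from Python for ad-hoc checks. The sketch below assumes the repository is on the PYTHONPATH; the file paths are placeholders following the data layout used in the scripts above. eval_cells returns a false-merge score, a false-split score and the number of annotations considered; eval_nuclei returns the corresponding false-positive and false-negative scores for nuclei.

    # minimal sketch, assuming placeholder paths and an importable 'scripts' package
    from scripts.segmentation.validation import eval_cells, eval_nuclei, get_ignore_seg_ids

    annotations = '../../data/rawdata/evaluation/validation_annotations.h5'
    cell_seg = '../../data/0.6.1/segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5'
    nucleus_seg = '../../data/0.0.0/segmentations/sbem-6dpf-1-whole-segmented-nuclei-labels.h5'
    region_table = '../../data/0.6.1/tables/sbem-6dpf-1-whole-segmented-cells-labels/regions.csv'

    # cell segmentation: exclude region ids from the table, then score against the annotated slices
    ignore_ids = get_ignore_seg_ids(region_table)
    fm, fs, n_cells = eval_cells(cell_seg, 't00000/s00/0/cells', annotations,
                                 ignore_seg_ids=ignore_ids)
    print("cells -> false merges:", fm, "false splits:", fs, "annotations:", n_cells)

    # nucleus segmentation: false positive / false negative scores on the same slices
    fp, fn, n_nuclei = eval_nuclei(nucleus_seg, 't00000/s00/0/cells', annotations)
    print("nuclei -> false positives:", fp, "false negatives:", fn, "annotations:", n_nuclei)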