From 589775d88be73f81e3847f92ccbf41f199d8c696 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Mon, 2 Sep 2019 11:13:05 +0200
Subject: [PATCH] Start to implment muscle correction

---
 scripts/segmentation/muscle/__init__.py       |   2 +
 scripts/segmentation/muscle/muscle_mapping.py | 163 ++++++++++++++
 scripts/segmentation/muscle/workflow.py       | 204 ++++++++++++++++++
 segmentation/muscles.py                       |  21 ++
 4 files changed, 390 insertions(+)
 create mode 100644 scripts/segmentation/muscle/__init__.py
 create mode 100644 scripts/segmentation/muscle/muscle_mapping.py
 create mode 100644 scripts/segmentation/muscle/workflow.py
 create mode 100644 segmentation/muscles.py

diff --git a/scripts/segmentation/muscle/__init__.py b/scripts/segmentation/muscle/__init__.py
new file mode 100644
index 0000000..08487e7
--- /dev/null
+++ b/scripts/segmentation/muscle/__init__.py
@@ -0,0 +1,2 @@
+from .muscle_mapping import predict_muscle_mapping
+from .workflow import run_workflow
diff --git a/scripts/segmentation/muscle/muscle_mapping.py b/scripts/segmentation/muscle/muscle_mapping.py
new file mode 100644
index 0000000..4fa6c52
--- /dev/null
+++ b/scripts/segmentation/muscle/muscle_mapping.py
@@ -0,0 +1,163 @@
+import os
+import h5py
+import numpy as np
+import pandas as pd
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import StratifiedKFold
+
+
+def compute_labels(root):
+    """Load the per-cell muscle labels from the regions table.
+
+    Reads the tab-separated 'regions.csv' of the cell segmentation below
+    `root`. Returns the 'muscle' column as uint8 array and the 'label_id'
+    column. The labels must consist of exactly the two values 0 and 1.
+    """
+    table = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'regions.csv')
+    table = pd.read_csv(table, sep='\t')
+    label_ids = table['label_id'].values
+    labels = table['muscle'].values.astype('uint8')
+    assert np.array_equal(np.unique(labels), [0, 1])
+    return labels, label_ids
+
+
+def compute_features(root):
+    """Assemble the per-cell feature matrix from the tables below `root`.
+
+    Combines the pixel count and bounding box extent from the default
+    table, four shape features from the cell morphology table and six
+    shape / intensity features of the nucleus mapped to each cell.
+    Only cells that are listed in the morphology table and that are
+    mapped to a nucleus id > 0 are kept.
+
+    Returns the feature matrix, the label ids of the kept cells and
+    the list of feature names.
+    """
+
+    feature_names = []
+
+    # we take the number of pixels and calculate the size of
+    # the bounding box from the default table
+    default = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'default.csv')
+    default = pd.read_csv(default, sep='\t')
+
+    n_pixels = default['n_pixels'].values
+    bb_min = np.array([default['bb_min_z'].values,
+                       default['bb_min_y'].values,
+                       default['bb_min_x'].values]).T
+    bb_max = np.array([default['bb_max_z'].values,
+                       default['bb_max_y'].values,
+                       default['bb_max_x'].values]).T
+    bb_shape = bb_max - bb_min
+    features_def = np.concatenate([n_pixels[:, None], bb_shape], axis=1)
+    feature_names.extend(['n_pixels', 'bb_shape_z', 'bb_shape_y', 'bb_shape_x'])
+
+    morpho = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'morphology.csv')
+    morpho = pd.read_csv(morpho, sep='\t')
+    label_ids_morpho = morpho['label_id'].values.astype('uint64')
+    features_morpho = morpho[['shape_volume_in_microns', 'shape_extent',
+                              'shape_surface_area', 'shape_sphericity']].values
+    feature_names.extend(['volume', 'extent', 'surface_area', 'sphericty'])
+
+    # add the nucleus features
+    nucleus_mapping = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels',
+                                   'cells_to_nuclei.csv')
+    nucleus_mapping = pd.read_csv(nucleus_mapping, sep='\t')
+    label_ids_nuclei = nucleus_mapping['label_id'].values.astype('uint64')
+    nucleus_ids = nucleus_mapping['nucleus_id'].values.astype('uint64')
+    nucleus_mask = nucleus_ids > 0
+    nucleus_ids = nucleus_ids[nucleus_mask]
+    label_ids_nuclei = label_ids_nuclei[nucleus_mask]
+
+    nucleus_features = os.path.join(root, 'sbem-6dpf-1-whole-segmented-nuclei-labels',
+                                    'morphology.csv')
+    nucleus_features = pd.read_csv(nucleus_features, sep='\t')
+    nucleus_features.set_index('label_id', inplace=True)
+    nucleus_features = nucleus_features.loc[nucleus_ids]
+    nucleus_features = nucleus_features[['shape_volume_in_microns', 'shape_extent',
+                                         'shape_surface_area', 'shape_sphericity',
+                                         'intensity_mean', 'intensity_st_dev']].values
+    feature_names.extend(['nucleus_volume', 'nucleus_extent', 'nucleus_surface_area',
+                          'nucleus_sphericity', 'nucleus_intensity_mean', 'nucleus_intensity_std'])
+
+    # TODO
+    # combine the features
+    # NOTE(review): the masks below index by label id, which assumes that
+    # row i of the default table belongs to label id i and that all tables
+    # list the cells in the same order - TODO confirm
+    label_id_mask = np.zeros(len(default), dtype='uint8')
+    label_id_mask[label_ids_morpho] += 1
+    label_id_mask[label_ids_nuclei] += 1
+    label_id_mask = label_id_mask > 1
+    valid_label_ids = np.where(label_id_mask)[0]
+
+    label_id_mask_morpho = np.isin(label_ids_morpho, valid_label_ids)
+    label_id_mask_nuclei = np.isin(label_ids_nuclei, valid_label_ids)
+
+    features = np.concatenate([features_def[label_id_mask],
+                               features_morpho[label_id_mask_morpho],
+                               nucleus_features[label_id_mask_nuclei]], axis=1)
+    assert len(features) == len(valid_label_ids), "%i, %i" % (len(features),
+                                                              len(valid_label_ids))
+    return features, valid_label_ids, feature_names
+
+
+def compute_features_and_labels(root):
+    """Compute the features and the matching muscle labels.
+
+    Returns features, labels and label ids restricted to the cells for
+    which `compute_features` could assemble all features.
+    """
+    features, label_ids, _ = compute_features(root)
+    labels, _ = compute_labels(root)
+
+    # NOTE(review): assumes row i of the regions table is label id i - TODO confirm
+    labels = labels[label_ids]
+    assert len(labels) == len(features) == len(label_ids)
+    return features, labels, label_ids
+
+
+def predict_muscle_mapping(root, project_path, n_folds=4):
+    """Predict muscle label candidates for proofreading.
+
+    Trains a random forest on the muscle labels in a stratified
+    `n_folds`-fold cross validation and collects the label ids of all
+    held-out cells where the prediction disagrees with the label.
+    The ids are written to the hdf5 file `project_path` as the datasets
+    'false_positives/prediction' and 'false_negatives/prediction'.
+    """
+    print("Computig labels and features ...")
+    features, labels, label_ids = compute_features_and_labels(root)
+    print("Found", len(features), "samples and", features.shape[1], "features per sample")
+    kf = StratifiedKFold(n_splits=n_folds, shuffle=True)
+    false_pos_ids = []
+    false_neg_ids = []
+
+    print("Find false positives and negatives on", n_folds, "folds")
+    for train_idx, test_idx in kf.split(features, labels):
+        x, y = features[train_idx], labels[train_idx]
+        rf = RandomForestClassifier(n_estimators=50, n_jobs=8)
+        rf.fit(x, y)
+
+        x, y = features[test_idx], labels[test_idx]
+        # allow adapting the threshold ?
+        pred = rf.predict(x)
+
+        false_pos = np.logical_and(pred == 1, y == 0)
+        false_neg = np.logical_and(pred == 0, y == 1)
+
+        fold_labels = label_ids[test_idx]
+        false_pos_ids.extend(fold_labels[false_pos].tolist())
+        false_neg_ids.extend(fold_labels[false_neg].tolist())
+
+    print("Found", len(false_pos_ids), "false positves")
+    print("Found", len(false_neg_ids), "false negatitves")
+
+    # save false pos and false negatives
+    print("Serializing results to", project_path)
+    # open in append mode explicitly: h5py >= 3 defaults to read-only,
+    # which would make the dataset creation below fail
+    with h5py.File(project_path, 'a') as f:
+        f.create_dataset('false_positives/prediction', data=false_pos_ids)
+        f.create_dataset('false_negatives/prediction', data=false_neg_ids)
diff --git a/scripts/segmentation/muscle/workflow.py b/scripts/segmentation/muscle/workflow.py
new file mode 100644
index 0000000..abae948
--- /dev/null
+++ b/scripts/segmentation/muscle/workflow.py
@@ -0,0 +1,204 @@
+import os
+import h5py
+import z5py
+import pandas as pd
+import napari
+from heimdall import view, to_source
+from elf.wrapper.resized_volume import ResizedVolume
+
+
+def scale_to_res(scale):
+    """Return the resolution for the given scale level.
+
+    Only the scale levels 1 to 5 are known, other values raise a KeyError.
+    The three values are used in z, y, x order by `get_bb`.
+    """
+    res = {1: [.025, .02, .02],
+           2: [.05, .04, .04],
+           3: [.1, .08, .08],
+           4: [.2, .16, .16],
+           5: [.4, .32, .32]}
+    return res[scale]
+
+
+def get_bb(table, lid, res):
+    """Get the bounding box of label `lid` as tuple of slices.
+
+    The bounding box coordinates from the table row are divided by the
+    resolution `res` and truncated to integers.
+    """
+    # NOTE(review): `loc` looks up by index label, so this assumes that the
+    # table index corresponds to the label ids - TODO confirm
+    row = table.loc[lid]
+    bb_min = [row.bb_min_z, row.bb_min_y, row.bb_min_x]
+    bb_max = [row.bb_max_z, row.bb_max_y, row.bb_max_x]
+    return tuple(slice(int(mi / re), int(ma / re)) for mi, ma, re in zip(bb_min, bb_max, res))
+
+
+def view_candidate(raw, mask, muscle):
+    """Show a candidate in napari and return the user decisions.
+
+    Displays the raw data, the mask of the candidate and the muscle
+    segmentation. The key bindings set the returned flags:
+    'y' confirms the id, 'f' marks a false merge, 's' requests saving
+    the state and 'q' requests to quit.
+    """
+    save_id, false_merge, save_state, done = False, False, False, False
+    with napari.gui_qt():
+        viewer = view(to_source(raw, name='raw'),
+                      to_source(mask, name='prediction'),
+                      to_source(muscle, name='muscle-segmentation'),
+                      return_viewer=True)
+
+        # add key bindings
+        @viewer.bind_key('y')
+        def confirm_id(viewer):
+            print("Confirm id requested")
+            nonlocal save_id
+            save_id = True
+
+        @viewer.bind_key('f')
+        def add_false_merge(viewer):
+            print("False merge requested")
+            nonlocal false_merge
+            false_merge = True
+
+        @viewer.bind_key('s')
+        def save(viewer):
+            print("Save state requested")
+            nonlocal save_state
+            save_state = True
+
+        @viewer.bind_key('q')
+        def quit_(viewer):
+            print("Quit requested")
+            nonlocal done
+            done = True
+
+    return save_id, false_merge, save_state, done
+
+
+def check_ids(remaining_ids, saved_ids, false_merges, project_path, state):
+    """Proofread the candidates in `remaining_ids` interactively.
+
+    Confirmed ids are appended to `saved_ids` and ids marked as false
+    merge to `false_merges`. On request, the current state is saved to
+    the hdf5 file `project_path`. `state` provides the entries 'check_fps'
+    (check false positives if true, else false negatives) and
+    'current_id' (offset of `remaining_ids` in the predictions).
+
+    Returns False if the user requested to quit and True if all
+    candidates were checked.
+    """
+    scale = 3
+    pathr = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    paths = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.3.1',
+                         'segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    table_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.3.1',
+                              'tables/sbem-6dpf-1-whole-segmented-cells-labels/default.csv')
+
+    table = pd.read_csv(table_path, sep='\t')
+    res = scale_to_res(scale)
+
+    fr = z5py.File(pathr, 'r')
+    dsr = fr['volumes/raw/s%i' % scale]
+    km = 'volumes/labels/muscle'
+    dsm = fr[km]
+    dsm = ResizedVolume(dsm, shape=dsr.shape)
+    assert dsm.shape == dsr.shape
+
+    check_fps, current_id = state['check_fps'], state['current_id']
+
+    with h5py.File(paths, 'r') as fs:
+        dss = fs['t00000/s00/%i/cells' % (scale - 1,)]
+
+        for ii, fid in enumerate(remaining_ids):
+            bb = get_bb(table, fid, res)
+            if check_fps:
+                print("Checking false positives - id:", fid)
+            else:
+                print("Checking false negatives - id:", fid)
+            raw = dsr[bb]
+            seg = dss[bb]
+            muscle = dsm[bb]
+            muscle = (muscle > 0).astype('uint32')
+            mask = (seg == fid).astype('uint32')
+            save_id, false_merge, save_state, done = view_candidate(raw, mask, muscle)
+
+            if save_id:
+                saved_ids.append(fid)
+                print("Confirm id", fid, "we now have", len(saved_ids), "confirmed ids.")
+
+            if false_merge:
+                print("Add id", fid, "to false merges")
+                false_merges.append(fid)
+
+            if save_state:
+                print("Save current state to", project_path)
+                # append mode is required to update the state,
+                # h5py >= 3 opens files read-only by default
+                with h5py.File(project_path, 'a') as f:
+                    f.attrs['check_fps'] = check_fps
+
+                    if 'false_merges' in f:
+                        del f['false_merges']
+                    if len(false_merges) > 0:
+                        f.create_dataset('false_merges', data=false_merges)
+
+                    g = f['false_positives'] if check_fps else f['false_negatives']
+                    g.attrs['current_id'] = current_id + ii + 1
+                    if 'proofread' in g:
+                        del g['proofread']
+                    if len(saved_ids) > 0:
+                        g.create_dataset('proofread', data=saved_ids)
+
+            if done:
+                print("Quit")
+                return False
+    return True
+
+
+def load_state(g):
+    """Load the proofreading state from the hdf5 group `g`.
+
+    Returns the predicted ids that still need to be checked, the list
+    of already confirmed ids and the current position in the predictions.
+    """
+    current_id = g.attrs.get('current_id', 0)
+    remaining_ids = g['prediction'][current_id:]
+    saved_ids = g['proofread'][:].tolist() if 'proofread' in g else []
+    return remaining_ids, saved_ids, current_id
+
+
+def load_false_merges(f):
+    """Load the list of false merges from `f`, empty if none were saved."""
+    false_merges = f['false_merges'][:].tolist() if 'false_merges' in f else []
+    return false_merges
+
+
+def run_workflow(project_path):
+    """Run the proofreading workflow for the project in `project_path`.
+
+    Continues with the false positives or false negatives, depending on
+    the saved state. If all false positives were checked in this run,
+    the false negatives are checked afterwards.
+    """
+    print("Start muscle proofreading workflow from", project_path)
+    with h5py.File(project_path, 'r') as f:
+        attrs = f.attrs
+        check_fps = attrs.get('check_fps', True)
+        false_merges = load_false_merges(f)
+        g = f['false_positives'] if check_fps else f['false_negatives']
+        remaining_ids, saved_ids, current_id = load_state(g)
+
+    state = {'check_fps': check_fps, 'current_id': current_id}
+    print("Continue workflow for", "false positives" if check_fps else "false negatives", "from id", current_id)
+    done = check_ids(remaining_ids, saved_ids, false_merges, project_path, state)
+
+    if check_fps and done:
+        with h5py.File(project_path, 'r') as f:
+            g = f['false_negatives']
+            remaining_ids, saved_ids, current_id = load_state(g)
+        state = {'check_fps': False, 'current_id': current_id}
+        print("Start workflow for false negatives from id", current_id)
+        check_ids(remaining_ids, saved_ids, false_merges, project_path, state)
diff --git a/segmentation/muscles.py b/segmentation/muscles.py
new file mode 100644
index 0000000..067f242
--- /dev/null
+++ b/segmentation/muscles.py
@@ -0,0 +1,21 @@
+#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
+from scripts.segmentation.muscle import predict_muscle_mapping, run_workflow
+
+
+PROJECT_PATH = '/g/kreshuk/pape/Work/muscle_mapping_v1.h5'
+ROOT = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.5.1/tables'
+
+
+def precompute():
+    """Predict the muscle candidates and write them to the project file."""
+    predict_muscle_mapping(ROOT, PROJECT_PATH)
+
+
+def proofreading():
+    """Run the interactive proofreading on the project file."""
+    run_workflow(PROJECT_PATH)
+
+
+if __name__ == '__main__':
+    # precompute()
+    proofreading()
-- 
GitLab