Skip to content
Snippets Groups Projects
Commit 589775d8 authored by Constantin Pape's avatar Constantin Pape
Browse files

Start to implement muscle correction

parent 926bf834
No related branches found
No related tags found
No related merge requests found
from.muscle_mapping import predict_muscle_mapping
from .workflow import run_workflow
import os
import h5py
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
def compute_labels(root):
    """Load the binary muscle labels and their label ids from the regions table.

    Arguments:
        root [str] - root folder containing the attribute tables

    Returns:
        labels [np.ndarray] - uint8 array of 0/1 muscle assignments
        label_ids [np.ndarray] - the corresponding label-id column
    """
    table_path = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'regions.csv')
    region_table = pd.read_csv(table_path, sep='\t')
    label_ids = region_table['label_id'].values
    labels = region_table['muscle'].values.astype('uint8')
    # sanity check: the muscle column must be strictly binary
    assert np.array_equal(np.unique(labels), [0, 1])
    return labels, label_ids
def compute_features(root):
    """Assemble the per-cell feature matrix from the attribute tables in `root`.

    Combines pixel-count / bounding-box features from the default table,
    morphology features of the cells and morphology + intensity features of
    the mapped nuclei. Only cells that appear in the morphology table AND
    have a nucleus mapped are kept.

    Returns:
        features [np.ndarray] - (n_valid_cells, n_features) feature matrix
        valid_label_ids [np.ndarray] - cell label ids for the feature rows
        feature_names [list] - one name per feature column
    """
    feature_names = []
    # we take the number of pixels and calculate the size of
    # the bounding box from the default table
    default = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'default.csv')
    default = pd.read_csv(default, sep='\t')
    n_pixels = default['n_pixels'].values
    bb_min = np.array([default['bb_min_z'].values,
                       default['bb_min_y'].values,
                       default['bb_min_x'].values]).T
    bb_max = np.array([default['bb_max_z'].values,
                       default['bb_max_y'].values,
                       default['bb_max_x'].values]).T
    bb_shape = bb_max - bb_min
    features_def = np.concatenate([n_pixels[:, None], bb_shape], axis=1)
    feature_names.extend(['n_pixels', 'bb_shape_z', 'bb_shape_y', 'bb_shape_x'])

    # morphology features computed on the cell segmentation
    morpho = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'morphology.csv')
    morpho = pd.read_csv(morpho, sep='\t')
    label_ids_morpho = morpho['label_id'].values.astype('uint64')
    features_morpho = morpho[['shape_volume_in_microns', 'shape_extent',
                              'shape_surface_area', 'shape_sphericity']].values
    # NOTE(review): 'sphericty' typo kept as-is, the name may already be
    # consumed downstream — confirm before renaming
    feature_names.extend(['volume', 'extent', 'surface_area', 'sphericty'])

    # add the nucleus features: resolve the cell -> nucleus mapping and look
    # up the nucleus morphology / intensity statistics
    nucleus_mapping = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels',
                                   'cells_to_nuclei.csv')
    nucleus_mapping = pd.read_csv(nucleus_mapping, sep='\t')
    label_ids_nuclei = nucleus_mapping['label_id'].values.astype('uint64')
    nucleus_ids = nucleus_mapping['nucleus_id'].values.astype('uint64')
    # nucleus id 0 means no nucleus was mapped -> drop those cells
    nucleus_mask = nucleus_ids > 0
    nucleus_ids = nucleus_ids[nucleus_mask]
    label_ids_nuclei = label_ids_nuclei[nucleus_mask]
    nucleus_features = os.path.join(root, 'sbem-6dpf-1-whole-segmented-nuclei-labels',
                                    'morphology.csv')
    nucleus_features = pd.read_csv(nucleus_features, sep='\t')
    nucleus_features.set_index('label_id', inplace=True)
    nucleus_features = nucleus_features.loc[nucleus_ids]
    nucleus_features = nucleus_features[['shape_volume_in_microns', 'shape_extent',
                                         'shape_surface_area', 'shape_sphericity',
                                         'intensity_mean', 'intensity_st_dev']].values
    feature_names.extend(['nucleus_volume', 'nucleus_extent', 'nucleus_surface_area',
                          'nucleus_sphericity', 'nucleus_intensity_mean', 'nucleus_intensity_std'])

    # TODO
    # combine the features
    # keep only the label ids present in both the morphology table and the
    # nucleus mapping. NOTE(review): this assumes label ids directly index
    # rows of the default table (label_id == row number) — confirm upstream
    label_id_mask = np.zeros(len(default), dtype='uint8')
    label_id_mask[label_ids_morpho] += 1
    label_id_mask[label_ids_nuclei] += 1
    label_id_mask = label_id_mask > 1
    valid_label_ids = np.where(label_id_mask)[0]

    # restrict each per-table feature array to the common label ids
    label_id_mask_morpho = np.isin(label_ids_morpho, valid_label_ids)
    label_id_mask_nuclei = np.isin(label_ids_nuclei, valid_label_ids)
    features = np.concatenate([features_def[label_id_mask],
                               features_morpho[label_id_mask_morpho],
                               nucleus_features[label_id_mask_nuclei]], axis=1)
    assert len(features) == len(valid_label_ids), "%i, %i" % (len(features),
                                                              len(valid_label_ids))
    return features, valid_label_ids, feature_names
def compute_features_and_labels(root):
    """Compute the feature matrix together with the matching muscle labels.

    Returns:
        features [np.ndarray] - feature matrix for the valid label ids
        labels [np.ndarray] - muscle label (0/1) per feature row
        valid_ids [np.ndarray] - the valid label ids
    """
    all_labels, _ = compute_labels(root)
    features, valid_ids, _ = compute_features(root)
    # restrict the labels to the ids for which we have a full feature set
    labels = all_labels[valid_ids]
    assert len(features) == len(labels) == len(valid_ids)
    return features, labels, valid_ids
def predict_muscle_mapping(root, project_path, n_folds=4):
    """Predict muscle vs. non-muscle for all cells and serialize the
    out-of-fold false positives / false negatives to the project file.

    Arguments:
        root [str] - root folder containing the attribute tables
        project_path [str] - path to the h5 project file for the results
        n_folds [int] - number of cross-validation folds (default: 4)
    """
    print("Computing labels and features ...")
    features, labels, label_ids = compute_features_and_labels(root)
    print("Found", len(features), "samples and", features.shape[1], "features per sample")

    kf = StratifiedKFold(n_splits=n_folds, shuffle=True)
    false_pos_ids = []
    false_neg_ids = []

    print("Find false positives and negatives on", n_folds, "folds")
    for train_idx, test_idx in kf.split(features, labels):
        x, y = features[train_idx], labels[train_idx]
        rf = RandomForestClassifier(n_estimators=50, n_jobs=8)
        rf.fit(x, y)

        x, y = features[test_idx], labels[test_idx]
        # TODO allow adapting the decision threshold instead of predict's default?
        pred = rf.predict(x)

        # out-of-fold errors: predicted muscle but labeled non-muscle, and vice versa
        false_pos = np.logical_and(pred == 1, y == 0)
        false_neg = np.logical_and(pred == 0, y == 1)

        fold_labels = label_ids[test_idx]
        false_pos_ids.extend(fold_labels[false_pos].tolist())
        false_neg_ids.extend(fold_labels[false_neg].tolist())

    print("Found", len(false_pos_ids), "false positives")
    print("Found", len(false_neg_ids), "false negatives")

    # save false positives and false negatives
    print("Serializing results to", project_path)
    # explicit mode 'a': h5py >= 3 no longer accepts a default file mode;
    # append mode creates the file if it does not exist yet
    with h5py.File(project_path, 'a') as f:
        f.create_dataset('false_positives/prediction', data=false_pos_ids)
        f.create_dataset('false_negatives/prediction', data=false_neg_ids)
import os
import h5py
import z5py
import pandas as pd
import napari
from heimdall import view, to_source
from elf.wrapper.resized_volume import ResizedVolume
def scale_to_res(scale):
    """Return the voxel resolution [z, y, x] for the given pyramid scale level.

    Valid scales are 1-5; raises KeyError otherwise.
    """
    resolutions = {
        1: [.025, .02, .02],
        2: [.05, .04, .04],
        3: [.1, .08, .08],
        4: [.2, .16, .16],
        5: [.4, .32, .32],
    }
    return resolutions[scale]
def get_bb(table, lid, res):
    """Translate the physical bounding box of label `lid` into pixel slices.

    Looks up the bb_min/bb_max columns for the table row `lid` and divides
    by the per-axis resolution `res` ([z, y, x]) to obtain integer slices.
    """
    row = table.loc[lid]
    starts = (row.bb_min_z, row.bb_min_y, row.bb_min_x)
    stops = (row.bb_max_z, row.bb_max_y, row.bb_max_x)
    return tuple(
        slice(int(start / r), int(stop / r))
        for start, stop, r in zip(starts, stops, res)
    )
def view_candidate(raw, mask, muscle):
    """Show one candidate cell in napari and collect the reviewer's decision.

    Opens a blocking napari viewer with the raw data, the candidate mask and
    the muscle segmentation. Key bindings: 'y' confirms the id, 'f' flags a
    false merge, 's' requests saving the state, 'q' quits the workflow.

    Returns:
        (save_id, false_merge, save_state, done) - booleans reflecting the
        keys pressed before the viewer window was closed.
    """
    save_id, false_merge, save_state, done = False, False, False, False
    # gui_qt blocks until the viewer window is closed
    with napari.gui_qt():
        viewer = view(to_source(raw, name='raw'),
                      to_source(mask, name='prediction'),
                      to_source(muscle, name='muscle-segmentation'),
                      return_viewer=True)

        # add key bindings; the handlers flip the nonlocal decision flags
        @viewer.bind_key('y')
        def confirm_id(viewer):
            print("Confirm id requested")
            nonlocal save_id
            save_id = True

        @viewer.bind_key('f')
        def add_false_merge(viewer):
            print("False merge requested")
            nonlocal false_merge
            false_merge = True

        @viewer.bind_key('s')
        def save(viewer):
            print("Save state requested")
            nonlocal save_state
            save_state = True

        @viewer.bind_key('q')
        def quit_(viewer):
            print("Quit requested")
            nonlocal done
            done = True
    return save_id, false_merge, save_state, done
def check_ids(remaining_ids, saved_ids, false_merges, project_path, state):
    """Interactively proofread candidate ids in napari.

    Iterates over `remaining_ids`, shows each candidate with `view_candidate`
    and applies the reviewer's decisions. Mutates `saved_ids` and
    `false_merges` in place and serializes progress to `project_path`
    whenever the reviewer requests a save.

    Arguments:
        remaining_ids - iterable of label ids still to be checked
        saved_ids [list] - already confirmed ids (appended in place)
        false_merges [list] - flagged false-merge ids (appended in place)
        project_path [str] - h5 project file for persisting the state
        state [dict] - {'check_fps': bool, 'current_id': int} resume info

    Returns:
        bool - True if all ids were checked, False if the reviewer quit.
    """
    scale = 3
    # hard-coded cluster paths to the raw data, segmentation and default table
    pathr = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
    paths = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.3.1',
                         'segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
    table_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.3.1',
                              'tables/sbem-6dpf-1-whole-segmented-cells-labels/default.csv')
    table = pd.read_csv(table_path, sep='\t')
    res = scale_to_res(scale)

    fr = z5py.File(pathr, 'r')
    dsr = fr['volumes/raw/s%i' % scale]
    km = 'volumes/labels/muscle'
    dsm = fr[km]
    # resize the muscle segmentation on the fly to match the raw scale level
    dsm = ResizedVolume(dsm, shape=dsr.shape)
    assert dsm.shape == dsr.shape

    check_fps, current_id = state['check_fps'], state['current_id']
    with h5py.File(paths, 'r') as fs:
        # the bdv-style dataset uses scale - 1 relative to the n5 raw pyramid
        dss = fs['t00000/s00/%i/cells' % (scale - 1,)]
        for ii, fid in enumerate(remaining_ids):
            bb = get_bb(table, fid, res)
            if check_fps:
                print("Checking false positives - id:", fid)
            else:
                print("Checking false negatives - id:", fid)

            # load the candidate's bounding box from raw, segmentation and muscle
            raw = dsr[bb]
            seg = dss[bb]
            muscle = dsm[bb]
            muscle = (muscle > 0).astype('uint32')
            mask = (seg == fid).astype('uint32')
            save_id, false_merge, save_state, done = view_candidate(raw, mask, muscle)

            if save_id:
                saved_ids.append(fid)
                print("Confirm id", fid, "we now have", len(saved_ids), "confirmed ids.")
            if false_merge:
                print("Add id", fid, "to false merges")
                false_merges.append(fid)

            if save_state:
                print("Save current state to", project_path)
                # explicit mode 'a': h5py >= 3 no longer accepts a default
                # file mode; append mode creates the file if needed
                with h5py.File(project_path, 'a') as f:
                    f.attrs['check_fps'] = check_fps
                    if 'false_merges' in f:
                        del f['false_merges']
                    if len(false_merges) > 0:
                        f.create_dataset('false_merges', data=false_merges)
                    g = f['false_positives'] if check_fps else f['false_negatives']
                    # resume position: the id after the current one
                    g.attrs['current_id'] = current_id + ii + 1
                    if 'proofread' in g:
                        del g['proofread']
                    if len(saved_ids) > 0:
                        g.create_dataset('proofread', data=saved_ids)

            if done:
                print("Quit")
                return False
    return True
def load_state(g):
    """Restore the proofreading state from the h5 group `g`.

    Returns:
        remaining_ids - slice of g['prediction'] starting at the resume index
        saved_ids [list] - previously confirmed ids (empty if none stored)
        current_id [int] - resume index (0 if never saved)
    """
    current_id = g.attrs.get('current_id', 0)
    remaining_ids = g['prediction'][current_id:]
    if 'proofread' in g:
        saved_ids = g['proofread'][:].tolist()
    else:
        saved_ids = []
    return remaining_ids, saved_ids, current_id
def load_false_merges(f):
    """Read the flagged false-merge ids from the project file.

    Returns an empty list when no 'false_merges' dataset exists.
    """
    if 'false_merges' in f:
        return f['false_merges'][:].tolist()
    return []
def run_workflow(project_path):
    """Run (or resume) the interactive muscle proofreading workflow.

    First iterates the predicted false positives, then the false negatives.
    Progress is persisted in the h5 project file so the workflow can be
    resumed at the last saved position.
    """
    print("Start muscle proofreading workflow from", project_path)
    with h5py.File(project_path, 'r') as f:
        check_fps = f.attrs.get('check_fps', True)
        false_merges = load_false_merges(f)
        group = f['false_positives'] if check_fps else f['false_negatives']
        remaining_ids, saved_ids, current_id = load_state(group)

    state = {'check_fps': check_fps, 'current_id': current_id}
    print("Continue workflow for", "false positives" if check_fps else "false negatives", "from id", current_id)
    done = check_ids(remaining_ids, saved_ids, false_merges, project_path, state)

    # once all false positives are checked, continue with the false negatives
    if check_fps and done:
        with h5py.File(project_path, 'r') as f:
            remaining_ids, saved_ids, current_id = load_state(f['false_negatives'])
        state = {'check_fps': False, 'current_id': current_id}
        print("Start workflow for false negatives from id", current_id)
        check_ids(remaining_ids, saved_ids, false_merges, project_path, state)
#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
from scripts.segmentation.muscle import predict_muscle_mapping, run_workflow
PROJECT_PATH = '/g/kreshuk/pape/Work/muscle_mapping_v1.h5'
ROOT = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.5.1/tables'
def precompute():
    """Run the muscle prediction and store candidates in the project file."""
    predict_muscle_mapping(ROOT, PROJECT_PATH)
def proofreading():
    """Start or resume the interactive proofreading workflow."""
    run_workflow(PROJECT_PATH)
if __name__ == '__main__':
    # run `precompute` once to generate the candidate ids, then switch to
    # the interactive proofreading step
    # precompute()
    proofreading()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment