From 589775d88be73f81e3847f92ccbf41f199d8c696 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Mon, 2 Sep 2019 11:13:05 +0200
Subject: [PATCH] Start to implment muscle correction

---
 scripts/segmentation/muscle/__init__.py       |   2 +
 scripts/segmentation/muscle/muscle_mapping.py | 127 ++++++++++++++
 scripts/segmentation/muscle/workflow.py       | 160 ++++++++++++++++++
 segmentation/muscles.py                       |  19 +++
 4 files changed, 308 insertions(+)
 create mode 100644 scripts/segmentation/muscle/__init__.py
 create mode 100644 scripts/segmentation/muscle/muscle_mapping.py
 create mode 100644 scripts/segmentation/muscle/workflow.py
 create mode 100644 segmentation/muscles.py

diff --git a/scripts/segmentation/muscle/__init__.py b/scripts/segmentation/muscle/__init__.py
new file mode 100644
index 0000000..08487e7
--- /dev/null
+++ b/scripts/segmentation/muscle/__init__.py
@@ -0,0 +1,2 @@
+from.muscle_mapping import predict_muscle_mapping
+from .workflow import run_workflow
diff --git a/scripts/segmentation/muscle/muscle_mapping.py b/scripts/segmentation/muscle/muscle_mapping.py
new file mode 100644
index 0000000..4fa6c52
--- /dev/null
+++ b/scripts/segmentation/muscle/muscle_mapping.py
@@ -0,0 +1,127 @@
+import os
+import h5py
+import numpy as np
+import pandas as pd
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import StratifiedKFold
+
+
+def compute_labels(root):
+    table = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'regions.csv')
+    table = pd.read_csv(table, sep='\t')
+    label_ids = table['label_id'].values
+    labels = table['muscle'].values.astype('uint8')
+    assert np.array_equal(np.unique(labels), [0, 1])
+    return labels, label_ids
+
+
+def compute_features(root):
+
+    feature_names = []
+
+    # we take the number of pixels and calculat the size of
+    # the bounding bpx from the default table
+    default = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'default.csv')
+    default = pd.read_csv(default, sep='\t')
+
+    n_pixels = default['n_pixels'].values
+    bb_min = np.array([default['bb_min_z'].values,
+                       default['bb_min_y'].values,
+                       default['bb_min_x'].values]).T
+    bb_max = np.array([default['bb_max_z'].values,
+                       default['bb_max_y'].values,
+                       default['bb_max_x'].values]).T
+    bb_shape = bb_max - bb_min
+    features_def = np.concatenate([n_pixels[:, None], bb_shape], axis=1)
+    feature_names.extend(['n_pixels', 'bb_shape_z', 'bb_shape_y', 'bb_shape_x'])
+
+    morpho = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels', 'morphology.csv')
+    morpho = pd.read_csv(morpho, sep='\t')
+    label_ids_morpho = morpho['label_id'].values.astype('uint64')
+    features_morpho = morpho[['shape_volume_in_microns', 'shape_extent',
+                              'shape_surface_area', 'shape_sphericity']].values
+    feature_names.extend(['volume', 'extent', 'surface_area', 'sphericty'])
+
+    # add the nucleus features
+    nucleus_mapping = os.path.join(root, 'sbem-6dpf-1-whole-segmented-cells-labels',
+                                   'cells_to_nuclei.csv')
+    nucleus_mapping = pd.read_csv(nucleus_mapping, sep='\t')
+    label_ids_nuclei = nucleus_mapping['label_id'].values.astype('uint64')
+    nucleus_ids = nucleus_mapping['nucleus_id'].values.astype('uint64')
+    nucleus_mask = nucleus_ids > 0
+    nucleus_ids = nucleus_ids[nucleus_mask]
+    label_ids_nuclei = label_ids_nuclei[nucleus_mask]
+
+    nucleus_features = os.path.join(root, 'sbem-6dpf-1-whole-segmented-nuclei-labels',
+                                    'morphology.csv')
+    nucleus_features = pd.read_csv(nucleus_features, sep='\t')
+    nucleus_features.set_index('label_id', inplace=True)
+    nucleus_features = nucleus_features.loc[nucleus_ids]
+    nucleus_features = nucleus_features[['shape_volume_in_microns', 'shape_extent',
+                                         'shape_surface_area', 'shape_sphericity',
+                                         'intensity_mean', 'intensity_st_dev']].values
+    feature_names.extend(['nucleus_volume', 'nucleus_extent', 'nucleus_surface_area',
+                          'nucleus_sphericity', 'nucleus_intensity_mean', 'nucleus_intensity_std'])
+
+    # TODO
+    # combine the features
+    label_id_mask = np.zeros(len(default), dtype='uint8')
+    label_id_mask[label_ids_morpho] += 1
+    label_id_mask[label_ids_nuclei] += 1
+    label_id_mask = label_id_mask > 1
+    valid_label_ids = np.where(label_id_mask)[0]
+
+    label_id_mask_morpho = np.isin(label_ids_morpho, valid_label_ids)
+    label_id_mask_nuclei = np.isin(label_ids_nuclei, valid_label_ids)
+
+    features = np.concatenate([features_def[label_id_mask],
+                               features_morpho[label_id_mask_morpho],
+                               nucleus_features[label_id_mask_nuclei]], axis=1)
+    assert len(features) == len(valid_label_ids), "%i, %i" % (len(features),
+                                                              len(valid_label_ids))
+    return features, valid_label_ids, feature_names
+
+
+def compute_features_and_labels(root):
+    features, label_ids, _ = compute_features(root)
+    labels, _ = compute_labels(root)
+
+    labels = labels[label_ids]
+    assert len(labels) == len(features) == len(label_ids)
+    return features, labels, label_ids
+
+
+def predict_muscle_mapping(root, project_path, n_folds=4):
+    print("Computig labels and features ...")
+    features, labels, label_ids = compute_features_and_labels(root)
+    print("Found", len(features), "samples and", features.shape[1], "features per sample")
+    kf = StratifiedKFold(n_splits=n_folds, shuffle=True)
+    false_pos_ids = []
+    false_neg_ids = []
+
+    print("Find false positives and negatives on", n_folds, "folds")
+    for train_idx, test_idx in kf.split(features, labels):
+        x, y = features[train_idx], labels[train_idx]
+        rf = RandomForestClassifier(n_estimators=50, n_jobs=8)
+        rf.fit(x, y)
+
+        x, y = features[test_idx], labels[test_idx]
+        # allow adapting the threshold ?
+        pred = rf.predict(x)
+
+        false_pos = np.logical_and(pred == 1, y == 0)
+        false_neg = np.logical_and(pred == 0, y == 1)
+
+        fold_labels = label_ids[test_idx]
+        false_pos_ids.extend(fold_labels[false_pos].tolist())
+        false_neg_ids.extend(fold_labels[false_neg].tolist())
+
+    print("Found", len(false_pos_ids), "false positves")
+    print("Found", len(false_neg_ids), "false negatitves")
+
+    # save false pos and false negatives
+    print("Serializing results to", project_path)
+    with h5py.File(project_path) as f:
+        f.create_dataset('false_positives/prediction', data=false_pos_ids)
+        f.create_dataset('false_negatives/prediction', data=false_neg_ids)
diff --git a/scripts/segmentation/muscle/workflow.py b/scripts/segmentation/muscle/workflow.py
new file mode 100644
index 0000000..abae948
--- /dev/null
+++ b/scripts/segmentation/muscle/workflow.py
@@ -0,0 +1,160 @@
+import os
+import h5py
+import z5py
+import pandas as pd
+import napari
+from heimdall import view, to_source
+from elf.wrapper.resized_volume import ResizedVolume
+
+
+def scale_to_res(scale):
+    res = {1: [.025, .02, .02],
+           2: [.05, .04, .04],
+           3: [.1, .08, .08],
+           4: [.2, .16, .16],
+           5: [.4, .32, .32]}
+    return res[scale]
+
+
+def get_bb(table, lid, res):
+    row = table.loc[lid]
+    bb_min = [row.bb_min_z, row.bb_min_y, row.bb_min_x]
+    bb_max = [row.bb_max_z, row.bb_max_y, row.bb_max_x]
+    return tuple(slice(int(mi / re), int(ma / re)) for mi, ma, re in zip(bb_min, bb_max, res))
+
+
+def view_candidate(raw, mask, muscle):
+    save_id, false_merge, save_state, done = False, False, False, False
+    with napari.gui_qt():
+        viewer = view(to_source(raw, name='raw'),
+                      to_source(mask, name='prediction'),
+                      to_source(muscle, name='muscle-segmentation'),
+                      return_viewer=True)
+
+        # add key bindings
+        @viewer.bind_key('y')
+        def confirm_id(viewer):
+            print("Confirm id requested")
+            nonlocal save_id
+            save_id = True
+
+        @viewer.bind_key('f')
+        def add_false_merge(viewer):
+            print("False merge requested")
+            nonlocal false_merge
+            false_merge = True
+
+        @viewer.bind_key('s')
+        def save(viewer):
+            print("Save state requested")
+            nonlocal save_state
+            save_state = True
+
+        @viewer.bind_key('q')
+        def quit_(viewer):
+            print("Quit requested")
+            nonlocal done
+            done = True
+
+    return save_id, false_merge, save_state, done
+
+
+def check_ids(remaining_ids, saved_ids, false_merges, project_path, state):
+    scale = 3
+    pathr = '/g/kreshuk/data/arendt/platyneris_v1/data.n5'
+    paths = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.3.1',
+                         'segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5')
+    table_path = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.3.1',
+                              'tables/sbem-6dpf-1-whole-segmented-cells-labels/default.csv')
+
+    table = pd.read_csv(table_path, sep='\t')
+    res = scale_to_res(scale)
+
+    fr = z5py.File(pathr, 'r')
+    dsr = fr['volumes/raw/s%i' % scale]
+    km = 'volumes/labels/muscle'
+    dsm = fr[km]
+    dsm = ResizedVolume(dsm, shape=dsr.shape)
+    assert dsm.shape == dsr.shape
+
+    check_fps, current_id = state['check_fps'], state['current_id']
+
+    with h5py.File(paths, 'r') as fs:
+        dss = fs['t00000/s00/%i/cells' % (scale - 1,)]
+
+        for ii, fid in enumerate(remaining_ids):
+            bb = get_bb(table, fid, res)
+            if check_fps:
+                print("Checking false positives - id:", fid)
+            else:
+                print("Checking false negatives - id:", fid)
+            raw = dsr[bb]
+            seg = dss[bb]
+            muscle = dsm[bb]
+            muscle = (muscle > 0).astype('uint32')
+            mask = (seg == fid).astype('uint32')
+            save_id, false_merge, save_state, done = view_candidate(raw, mask, muscle)
+
+            if save_id:
+                saved_ids.append(fid)
+                print("Confirm id", fid, "we now have", len(saved_ids), "confirmed ids.")
+
+            if false_merge:
+                print("Add id", fid, "to false merges")
+                false_merges.append(fid)
+
+            if save_state:
+                print("Save current state to", project_path)
+                with h5py.File(project_path) as f:
+                    f.attrs['check_fps'] = check_fps
+
+                    if 'false_merges' in f:
+                        del f['false_merges']
+                    if len(false_merges) > 0:
+                        f.create_dataset('false_merges', data=false_merges)
+
+                    g = f['false_positives'] if check_fps else f['false_negatives']
+                    g.attrs['current_id'] = current_id + ii + 1
+                    if 'proofread' in g:
+                        del g['proofread']
+                    if len(saved_ids) > 0:
+                        g.create_dataset('proofread', data=saved_ids)
+
+            if done:
+                print("Quit")
+                return False
+    return True
+
+
+def load_state(g):
+    current_id = g.attrs.get('current_id', 0)
+    remaining_ids = g['prediction'][current_id:]
+    saved_ids = g['proofread'][:].tolist() if 'proofread' in g else []
+    return remaining_ids, saved_ids, current_id
+
+
+def load_false_merges(f):
+    false_merges = f['false_merges'][:].tolist() if 'false_merges' in f else []
+    return false_merges
+
+
+def run_workflow(project_path):
+    print("Start  muscle proofreading workflow from", project_path)
+    with h5py.File(project_path, 'r') as f:
+        attrs = f.attrs
+        check_fps = attrs.get('check_fps', True)
+        false_merges = load_false_merges(f)
+        g = f['false_positives'] if check_fps else f['false_negatives']
+        remaining_ids, saved_ids, current_id = load_state(g)
+
+    state = {'check_fps': check_fps, 'current_id': current_id}
+    print("Continue workflow for", "false positives" if check_fps else "false negatives", "from id", current_id)
+    done = check_ids(remaining_ids, saved_ids, false_merges, project_path, state)
+
+    if check_fps and done:
+        with h5py.File(project_path, 'r') as f:
+            g = f['false_negatives']
+            remaining_ids, saved_ids, current_id = load_state(g)
+        state = {'check_fps': False, 'current_id': current_id}
+        print("Start workflow for false negatives from id", current_id)
+        check_ids(remaining_ids, saved_ids, false_merges, project_path, state)
diff --git a/segmentation/muscles.py b/segmentation/muscles.py
new file mode 100644
index 0000000..067f242
--- /dev/null
+++ b/segmentation/muscles.py
@@ -0,0 +1,19 @@
+#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
+from scripts.segmentation.muscle import predict_muscle_mapping, run_workflow
+
+
+PROJECT_PATH = '/g/kreshuk/pape/Work/muscle_mapping_v1.h5'
+ROOT = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data/0.5.1/tables'
+
+
+def precompute():
+    predict_muscle_mapping(ROOT, PROJECT_PATH)
+
+
+def proofreading():
+    run_workflow(PROJECT_PATH)
+
+
+if __name__ == '__main__':
+    # precompute()
+    proofreading()
-- 
GitLab