From 12e90e0132d45f74abbf6e2d7a759e52c113c324 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Tue, 1 Oct 2019 15:58:13 +0200
Subject: [PATCH] Implement workflow for raw data intensity correction

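Add an intensity correction workflow that applies a per-slice linear
transformation to the raw data within a mask combined from the inside
and resin masks, recomputes the scale pyramid and exports the corrected
volume to bdv format.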
---
 analysis/cell_volumes.py                      |   2 +-
 analysis/correct_intensities.py               |  56 +++++++++
 analysis/gene_expression.py                   |   2 +-
 scripts/transformation/__init__.py            |   1 +
 .../transformation/intensity_correction.py    | 116 ++++++++++++++++++
 5 files changed, 175 insertions(+), 2 deletions(-)
 create mode 100644 analysis/correct_intensities.py
 create mode 100644 scripts/transformation/__init__.py
 create mode 100644 scripts/transformation/intensity_correction.py

diff --git a/analysis/cell_volumes.py b/analysis/cell_volumes.py
index 961270b..7bebba5 100644
--- a/analysis/cell_volumes.py
+++ b/analysis/cell_volumes.py
@@ -1,4 +1,4 @@
-#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
+#! /g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python
 
 import argparse
 import os
diff --git a/analysis/correct_intensities.py b/analysis/correct_intensities.py
new file mode 100644
index 0000000..c3a14bb
--- /dev/null
+++ b/analysis/correct_intensities.py
@@ -0,0 +1,56 @@
+#! /g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python
+import os
+import numpy as np
+import h5py
+import z5py
+import vigra
+from scripts.transformation import intensity_correction
+
+
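+# Combine the "inside" and "resin" masks into a single mask (written to mask.n5)
+# that restricts the intensity correction to the relevant parts of the volume.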
+def combine_mask():
+    tmp_folder = './tmp_intensity_correction'
+    os.makedirs(tmp_folder, exist_ok=True)
+
+    mask_path1 = '../data/rawdata/sbem-6dpf-1-whole-mask-inside.h5'
+    mask_path2 = '../data/rawdata/sbem-6dpf-1-whole-mask-resin.h5'
+
+    print("Load inside mask ...")
+    with h5py.File(mask_path1, 'r') as f:
+        key = 't00000/s00/0/cells'
+        mask1 = f[key][:].astype('bool')
+    print("Load resin mask ..")
+    with h5py.File(mask_path2, 'r') as f:
+        key = 't00000/s00/1/cells'
+        mask2 = f[key][:]
+
+    print("Resize resin mask ...")
+    mask2 = vigra.sampling.resize(mask2.astype('float32'), mask1.shape, order=0).astype('bool')
+    mask = np.logical_or(mask1, mask2).astype('uint8')
+
+    out_path = 'mask.n5'
+    out_key = 'data'
+    with z5py.File(out_path) as f:
+        f.create_dataset(out_key, data=mask, compression='gzip', n_threads=8)
+
+
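+# Run the intensity correction workflow for the raw data, using the combined mask.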
+def correct_intensities(target='slurm', max_jobs=250):
+    raw_path = '../../EM-Prospr/em-raw-samplexy.h5'
+    tmp_folder = './tmp_intensity_correction'
+
+    mask_path = 'mask.n5'
+    mask_key = 'data'
+
+    out_path = 'em-raw-samplexy-corrected.h5'
+
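+    # per-slice correction values; a csv table is converted to json automatically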
+    # trafo = './new_vals.csv'
+    trafo = './new_vals.json'
+
+    resolution = [0.025, 0.32, 0.32]
+    intensity_correction(raw_path, out_path, mask_path, mask_key,
+                         trafo, tmp_folder, resolution,
+                         target=target, max_jobs=max_jobs)
+
+
+if __name__ == '__main__':
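+    # combine_mask only needs to run once to create mask.n5; uncomment for the first run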
+    # combine_mask()
+    correct_intensities('local', 32)
diff --git a/analysis/gene_expression.py b/analysis/gene_expression.py
index ebd5a86..3c83413 100644
--- a/analysis/gene_expression.py
+++ b/analysis/gene_expression.py
@@ -1,4 +1,4 @@
-#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
+#! /g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python
 import argparse
 import os
 import json
diff --git a/scripts/transformation/__init__.py b/scripts/transformation/__init__.py
new file mode 100644
index 0000000..2489f45
--- /dev/null
+++ b/scripts/transformation/__init__.py
@@ -0,0 +1 @@
+from .intensity_correction import intensity_correction
diff --git a/scripts/transformation/intensity_correction.py b/scripts/transformation/intensity_correction.py
new file mode 100644
index 0000000..281329c
--- /dev/null
+++ b/scripts/transformation/intensity_correction.py
@@ -0,0 +1,116 @@
+import os
+import json
+import luigi
+import pandas as pd
+
+from elf.io import open_file
+from cluster_tools.transformations import LinearTransformationWorkflow
+from cluster_tools.downscaling import DownscalingWorkflow, PainteraToBdvWorkflow
+from ..default_config import write_default_global_config
+
+
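+# Convert a tab-separated table with per-slice 'mult' and 'offset' columns
+# into the json format {z: {'a': mult, 'b': offset}} passed on to the workflow.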
+def csv_to_json(trafo_path):
+    trafo = pd.read_csv(trafo_path, sep='\t')
+    n_slices = trafo.shape[0]
+    trafo = {z: {'a': trafo.loc[z].mult, 'b': trafo.loc[z].offset} for z in range(n_slices)}
+    new_trafo_path = os.path.splitext(trafo_path)[0] + '.json'
+    with open(new_trafo_path, 'w') as f:
+        json.dump(trafo, f)
+    return new_trafo_path
+
+
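+# Check that the trafo file provides exactly one transformation per z-slice of the input.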
+def validate_trafo(trafo_path, in_path, in_key):
+    with open_file(in_path, 'r') as f:
+        n_slices = f[in_key].shape[0]
+    with open(trafo_path) as f:
+        trafo = json.load(f)
+    if n_slices != len(trafo):
+        raise RuntimeError("Invalid number of transformations: %i,%i" % (n_slices, len(trafo)))
+
+
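+# Downscale the corrected volume with the scale factors read from the reference
+# bdv file and export the result, including the scale pyramid, to bdv(hdf5) format.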
+def downsample(ref_path, in_path, in_key, out_path, resolution,
+               tmp_folder, target, max_jobs):
+    with open_file(ref_path, 'r') as f:
+        g = f['t00000/s00']
+        levels = list(g.keys())
+        levels.sort()
+
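+        # derive the scale factors between consecutive levels of the reference pyramid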
+        sample_factors = []
+        for level in range(1, len(levels)):
+            ds0 = g['%s/cells' % levels[level - 1]]
+            ds1 = g['%s/cells' % levels[level]]
+
+            s0 = ds0.shape
+            s1 = ds1.shape
+            factor = [sh0 // sh1 for sh0, sh1 in zip(s0, s1)]
+
+            sample_factors.append(factor)
+        assert len(sample_factors) == len(levels) - 1
+
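+    # configure the downscaling task to use the skimage library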
+    config_dir = os.path.join(tmp_folder, 'configs')
+    task = DownscalingWorkflow
+    config = task.get_config()['downscaling']
+    config.update({'library': 'skimage'})
+    with open(os.path.join(config_dir, 'downscaling.config'), 'w') as f:
+        json.dump(config, f)
+
+    tmp_key2 = 'downscaled'
+    halos = len(sample_factors) * [[0, 0, 0]]
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
+             max_jobs=max_jobs, target=target,
+             input_path=in_path, input_key=in_key,
+             scale_factors=sample_factors, halos=halos,
+             metadata_format='paintera', metadata_dict={'resolution': resolution},
+             output_path=in_path, output_key_prefix=tmp_key2)
+    ret = luigi.build([t], local_scheduler=True)
+    if not ret:
+        raise RuntimeError("Downscaling failed")
+
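+    # copy the paintera-style scale pyramid into a single bdv(hdf5) output file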
+    task = PainteraToBdvWorkflow
+    config = task.get_config()['copy_volume']
+    config.update({'threads_per_job': 32, 'mem_limit': 64, 'time_limit': 1400})
+    with open(os.path.join(config_dir, 'copy_volume.config'), 'w') as f:
+        json.dump(config, f)
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
+             max_jobs=1, target=target,
+             input_path=in_path, input_key_prefix=tmp_key2,
+             output_path=out_path,
+             metadata_dict={'resolution': resolution})
+    ret = luigi.build([t], local_scheduler=True)
+    if not ret:
+        raise RuntimeError("Downscaling failed")
+
+
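+# Apply a per-slice linear intensity transformation to the input volume within the
+# given mask, then downsample the corrected data and export it to bdv format.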
+def intensity_correction(in_path, out_path, mask_path, mask_key,
+                         trafo_path, tmp_folder, resolution,
+                         target='slurm', max_jobs=250):
+    trafo_ext = os.path.splitext(trafo_path)[1]
+    if trafo_ext == '.csv':
+        trafo_path = csv_to_json(trafo_path)
+    elif trafo_ext != '.json':
+        raise ValueError("Expected trafo as csv or json, got %s" % trafo_ext)
+
+    key = 't00000/s00/0/cells'
+    validate_trafo(trafo_path, in_path, key)
+
+    config_dir = os.path.join(tmp_folder, 'configs')
+    write_default_global_config(config_dir)
+
+    tmp_path = os.path.join(tmp_folder, 'data.n5')
+    tmp_key = 'data'
+
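+    # apply the per-slice linear transformation within the mask,
+    # writing the corrected data to a temporary n5 dataset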
+    task = LinearTransformationWorkflow
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
+             target=target, max_jobs=max_jobs,
+             input_path=in_path, input_key=key,
+             mask_path=mask_path, mask_key=mask_key,
+             output_path=tmp_path, output_key=tmp_key,
+             transformation=trafo_path)
+    ret = luigi.build([t], local_scheduler=True)
+    if not ret:
+        raise RuntimeError("Transformation failed")
+
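+    # recompute the scale pyramid for the corrected data and export to bdv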
+    downsample(in_path, tmp_path, tmp_key, out_path, resolution,
+               tmp_folder, target, max_jobs)
-- 
GitLab