From 38117e7ebd0348e55f92240045e919469b5859a2 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Wed, 2 Oct 2019 10:14:06 +0200
Subject: [PATCH] Update scripts for intensity correction and downscaling

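Dilate the cell mask before combining it with the resin mask and write
the combined mask in BigDataViewer format instead of n5. Add a
whole-volume intensity correction run with a configurable temporary
path, drop the first relative scale factor when exporting
segmentations, round the downsampling factors instead of truncating
them, and set time and memory limits for the cluster jobs.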
---
 analysis/correct_intensities.py               | 39 +++++++++++++++++----
 scripts/export/export_segmentation.py         |  3 ++-
 .../transformation/intensity_correction.py    | 18 ++++++++----
 3 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/analysis/correct_intensities.py b/analysis/correct_intensities.py
index c3a14bb..477ac7e 100644
--- a/analysis/correct_intensities.py
+++ b/analysis/correct_intensities.py
@@ -4,7 +4,10 @@ import numpy as np
 import h5py
 import z5py
 import vigra
+
+from scipy.ndimage.morphology import binary_dilation
 from scripts.transformation import intensity_correction
+from pybdv import make_bdv
 
 
 def combine_mask():
@@ -18,6 +21,8 @@ def combine_mask():
     with h5py.File(mask_path1, 'r') as f:
         key = 't00000/s00/0/cells'
         mask1 = f[key][:].astype('bool')
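+        # dilate the cell mask, presumably to close small gaps before it is combined with the resin mask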
+        mask1 = binary_dilation(mask1, iterations=4)
     print("Load resin mask ..")
     with h5py.File(mask_path2, 'r') as f:
         key = 't00000/s00/1/cells'
@@ -27,13 +32,14 @@
     mask2 = vigra.sampling.resize(mask2.astype('float32'), mask1.shape, order=0).astype('bool')
     mask = np.logical_or(mask1, mask2).astype('uint8')
 
-    out_path = 'mask.n5'
-    out_key = 'data'
-    with z5py.File(out_path) as f:
-        f.create_dataset(out_key, data=mask, compression='gzip', n_threads=8)
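+    # export as a BigDataViewer file with three 2x downscaling levels; resolution presumably (z, y, x) in micrometer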
+    res = [.4, .32, .32]
+    ds_factors = [[2, 2, 2], [2, 2, 2], [2, 2, 2]]
+    make_bdv(mask, 'mask.h5', ds_factors,
+             resolution=res, unit='micrometer')
 
 
-def correct_intensities(target='slurm', max_jobs=250):
+def correct_intensities_test(target='local', max_jobs=32):
     raw_path = '../../EM-Prospr/em-raw-samplexy.h5'
     tmp_folder = './tmp_intensity_correction'
 
@@ -51,6 +57,27 @@
                          target=target, max_jobs=max_jobs)
 
 
+def correct_intensities(target='slurm', max_jobs=250):
+    raw_path = '../data/rawdata/sbem-6dpf-1-whole-raw.h5'
+    tmp_folder = './tmp_intensity_correction'
+
+    mask_path = 'mask.h5'
+    mask_key = 't00000/s00/0/cells'
+
+    out_path = 'em-raw-wholecorrected.h5'
+
+    # trafo = './new_vals.csv'
+    trafo = './new_vals.json'
+    tmp_path = '/g/kreshuk/pape/Work/platy_tmp.n5'
+
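+    # resolution of the raw volume in micrometer, presumably (z, y, x)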
+    resolution = [0.025, 0.01, 0.01]
+    intensity_correction(raw_path, out_path, mask_path, mask_key,
+                         trafo, tmp_folder, resolution,
+                         target=target, max_jobs=max_jobs,
+                         tmp_path=tmp_path)
+
+
 if __name__ == '__main__':
     # combine_mask()
-    correct_intensities('local', 32)
+    correct_intensities('local', 64)
diff --git a/scripts/export/export_segmentation.py b/scripts/export/export_segmentation.py
index e9a1d8f..df64f95 100644
--- a/scripts/export/export_segmentation.py
+++ b/scripts/export/export_segmentation.py
@@ -24,7 +24,8 @@ def get_scale_factors(paintera_path, paintera_key):
         rel_factor = [int(sf // prev) for sf, prev in zip(factor, scale_factors[-1])]
         scale_factors.append(factor)
         rel_scales.append(rel_factor[::-1])
-    return rel_scales
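+    # the first relative factor presumably corresponds to scale 0 itself, so it is skipped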
+    return rel_scales[1:]
 
 
 def downscale(path, in_key, out_key,
diff --git a/scripts/transformation/intensity_correction.py b/scripts/transformation/intensity_correction.py
index 281329c..8ab1666 100644
--- a/scripts/transformation/intensity_correction.py
+++ b/scripts/transformation/intensity_correction.py
@@ -42,7 +42,8 @@ def downsample(ref_path, in_path, in_key, out_path, resolution,
 
             s0 = ds0.shape
             s1 = ds1.shape
-            factor = [sh0 // sh1 for sh0, sh1 in zip(s0, s1)]
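+            # round instead of floor, so shapes that are not exact multiples still yield the correct factor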
+            factor = [int(round(float(sh0) / sh1, 0)) for sh0, sh1 in zip(s0, s1)]
 
             sample_factors.append(factor)
         assert len(sample_factors) == len(levels) - 1
@@ -50,7 +51,8 @@ def downsample(ref_path, in_path, in_key, out_path, resolution,
     config_dir = os.path.join(tmp_folder, 'configs')
     task = DownscalingWorkflow
     config = task.get_config()['downscaling']
-    config.update({'library': 'skimage'})
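+    # time limit per job in minutes and memory limit in GB (presumably the cluster_tools convention)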
+    config.update({'library': 'skimage', 'time_limit': 240, 'mem_limit': 4})
     with open(os.path.join(config_dir, 'downscaling.config'), 'w') as f:
         json.dump(config, f)
 
@@ -85,7 +87,7 @@
 
 def intensity_correction(in_path, out_path, mask_path, mask_key,
                          trafo_path, tmp_folder, resolution,
-                         target='slurm', max_jobs=250):
+                         target='slurm', max_jobs=250, tmp_path=None):
     trafo_ext = os.path.splitext(trafo_path)[1]
     if trafo_ext == '.csv':
         trafo_path = csv_to_json(trafo_path)
@@ -98,10 +100,18 @@ def intensity_correction(in_path, out_path, mask_path, mask_key,
     config_dir = os.path.join(tmp_folder, 'configs')
     write_default_global_config(config_dir)
 
-    tmp_path = os.path.join(tmp_folder, 'data.n5')
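+    # allow overriding the temporary path, e.g. to place the large intermediate data on a bigger filesystem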
+    if tmp_path is None:
+        tmp_path = os.path.join(tmp_folder, 'data.n5')
     tmp_key = 'data'
 
     task = LinearTransformationWorkflow
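+    # raise the time and memory limits for the linear transformation jobs as well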
+    conf = task.get_config()['linear']
+    conf.update({'time_limit': 600, 'mem_limit': 4})
+    with open(os.path.join(config_dir, 'linear.config'), 'w') as f:
+        json.dump(conf, f)
+
     t = task(tmp_folder=tmp_folder, config_dir=config_dir,
              target=target, max_jobs=max_jobs,
              input_path=in_path, input_key=key,
-- 
GitLab