Commit 0701f5ec authored by Constantin Pape

Clean up intensity correction workflow

parent 0bc30be5
 #! /g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python
-import os
-import json
-from concurrent import futures
+import os
 import numpy as np
-import z5py
 import h5py
-import luigi
 import vigra
 from scipy.ndimage.morphology import binary_dilation
 from mmpb.transformation import intensity_correction
 from pybdv import make_bdv
-from pybdv.metadata import write_h5_metadata
+
+ROOT = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data'
+
+
+def correct_intensities(target='slurm', max_jobs=250):
+    raw_path = os.path.join(ROOT, 'rawdata/sbem-6dpf-1-whole-raw.h5')
+    tmp_folder = './tmp_intensity_correction'
+
+    # TODO need to check which mask this is and then take it from the data folder
+    mask_path = 'mask.h5'
+    mask_key = 't00000/s00/0/cells'
+
+    out_path = 'em-raw-wholecorrected.n5'
+    trafo = './intensity_correction_parameters.json'
+
+    resolution = [0.025, 0.01, 0.01]
+    intensity_correction(raw_path, out_path, mask_path, mask_key,
+                         trafo, tmp_folder, resolution,
+                         target=target, max_jobs=max_jobs)
 def combine_mask():
@@ -42,116 +56,6 @@ def combine_mask():
              resolution=res, unit='micrometer')
-def correct_intensities_test(target='local', max_jobs=32):
-    raw_path = '../../EM-Prospr/em-raw-samplexy.h5'
-    tmp_folder = './tmp_intensity_correction'
-
-    mask_path = 'mask.n5'
-    mask_key = 'data'
-
-    out_path = 'em-raw-samplexy-corrected.h5'
-    # trafo = './new_vals.csv'
-    trafo = './new_vals.json'
-
-    resolution = [0.025, 0.32, 0.32]
-    intensity_correction(raw_path, out_path, mask_path, mask_key,
-                         trafo, tmp_folder, resolution,
-                         target=target, max_jobs=max_jobs)
-
-
-def correct_intensities(target='slurm', max_jobs=250):
-    raw_path = '../data/rawdata/sbem-6dpf-1-whole-raw.h5'
-    tmp_folder = './tmp_intensity_correction'
-
-    mask_path = 'mask.h5'
-    mask_key = 't00000/s00/0/cells'
-
-    out_path = 'em-raw-wholecorrected.h5'
-    # trafo = './new_vals.csv'
-    trafo = './new_vals.json'
-
-    tmp_path = '/g/kreshuk/pape/Work/platy_tmp.n5'
-
-    resolution = [0.025, 0.01, 0.01]
-    intensity_correction(raw_path, out_path, mask_path, mask_key,
-                         trafo, tmp_folder, resolution,
-                         target=target, max_jobs=max_jobs,
-                         tmp_path=tmp_path)
-
-
-def check_chunks():
-    import nifty.tools as nt
-    path = '/g/kreshuk/pape/Work/platy_tmp.n5'
-    key = 'data'
-    f = z5py.File(path, 'r')
-    ds = f[key]
-    shape = ds.shape
-    chunks = ds.chunks
-
-    blocking = nt.blocking([0, 0, 0], shape, chunks)
-
-    def check_chunk(block_id):
-        print("Check", block_id, "/", blocking.numberOfBlocks)
-        block = blocking.getBlock(block_id)
-        chunk_id = tuple(beg // ch for beg, ch in zip(block.begin, chunks))
-        try:
-            ds.read_chunk(chunk_id)
-        except RuntimeError:
-            print("Failed:", chunk_id)
-            return chunk_id
-
-    print("Start checking", blocking.numberOfBlocks, "blocks")
-    with futures.ThreadPoolExecutor(32) as tp:
-        tasks = [tp.submit(check_chunk, block_id) for block_id in range(blocking.numberOfBlocks)]
-        results = [t.result() for t in tasks]
-
-    results = [res for res in results if res is not None]
-    print()
-    print(results)
-    print()
-
-    with open('./failed_chunks.json', 'w') as f:
-        json.dump(results, f)
-
-
-# TODO write hdf5 meta-data
-def make_subsampled_volume():
-    from cluster_tools.copy_volume import CopyVolumeLocal
-    task = CopyVolumeLocal
-
-    p = './em-raw-wholecorrected.h5'
-
-    glob_conf = task.default_global_config()
-    task_conf = task.default_task_config()
-
-    tmp_folder = './tmp_copy'
-    config_dir = os.path.join(tmp_folder, 'configs')
-    os.makedirs(config_dir, exist_ok=True)
-
-    out_path = 'em-raw-small-corrected.h5'
-
-    shebang = '/g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python'
-    block_shape = [64, 512, 512]
-    chunks = [32, 256, 256]
-
-    glob_conf.update({'shebang': shebang, 'block_shape': block_shape})
-    task_conf.update({'threads_per_job': 32, 'chunks': chunks})
-    with open(os.path.join(config_dir, 'global.config'), 'w') as f:
-        json.dump(glob_conf, f)
-    with open(os.path.join(config_dir, 'copy_volume.config'), 'w') as f:
-        json.dump(task_conf, f)
-
-    for scale in range(5):
-        pref = 's%i' % scale
-        in_key = 't00000/s00/%i/cells' % (scale + 3,)
-        out_key = 't00000/s00/%i/cells' % scale
-        t = task(tmp_folder=tmp_folder, config_dir=config_dir,
-                 max_jobs=1, input_path=p, input_key=in_key,
-                 output_path=out_path, output_key=out_key,
-                 prefix=pref)
-        luigi.build([t], local_scheduler=True)
-
-    scale_factors = 5 * [[2, 2, 2]]
-    write_h5_metadata(out_path, scale_factors)
 def make_extrapolation_mask():
     z0 = 800  # extrapolation for z < z0
     z1 = 9800  # extrapolation for z > z1
@@ -184,6 +88,5 @@ def make_extrapolation_mask():
 if __name__ == '__main__':
-    correct_intensities('local', 1)
-    # make_subsampled_volume()
+    correct_intensities()
     # make_extrapolation_mask()
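As a purely illustrative aside (not part of the commit), the cleaned-up workflow now writes its corrected volume to the N5 container 'em-raw-wholecorrected.n5', so a quick sanity check of the output could look like the sketch below; the key layout assumes pybdv's BDV-N5 conventions and should be verified against the installed pybdv version.

# Illustrative sketch only, not part of the commit.
# Opens the corrected N5 output produced by correct_intensities() and prints
# basic dataset attributes; the key layout assumes pybdv's BDV-N5 conventions.
import z5py
from pybdv.util import get_key

out_path = 'em-raw-wholecorrected.n5'
f = z5py.File(out_path, 'r')
key = get_key(False, 0, 0, 0)  # e.g. 'setup0/timepoint0/s0' for an N5 BDV container
ds = f[key]
print(ds.shape, ds.chunks, ds.dtype)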
This diff is collapsed.
@@ -4,8 +4,9 @@ import luigi
 import pandas as pd
 from elf.io import open_file
+from pybdv.util import get_key
 from cluster_tools.transformations import LinearTransformationWorkflow
-from cluster_tools.downscaling import DownscalingWorkflow, PainteraToBdvWorkflow
+from cluster_tools.downscaling import DownscalingWorkflow
 from ..default_config import write_default_global_config
@@ -30,15 +31,19 @@ def validate_trafo(trafo_path, in_path, in_key):
 def downsample(ref_path, in_path, in_key, out_path, resolution,
                tmp_folder, target, max_jobs):
+    ref_is_h5 = os.path.splitext(ref_path)[1].lower() in ('.hdf5', '.h5')
+    gkey = get_key(ref_is_h5, 0, 0)
     with open_file(ref_path, 'r') as f:
-        g = f['t00000/s00']
+        g = f[gkey]
         levels = list(g.keys())
         levels.sort()

         sample_factors = []
         for level in range(1, len(levels)):
-            ds0 = g['%s/cells' % levels[level - 1]]
-            ds1 = g['%s/cells' % levels[level]]
+            k0 = get_key(ref_is_h5, 0, 0, level - 1)
+            k1 = get_key(ref_is_h5, 0, 0, level)
+            ds0 = f[k0]
+            ds1 = f[k1]
             s0 = ds0.shape
             s1 = ds1.shape
@@ -54,30 +59,17 @@ def downsample(ref_path, in_path, in_key, out_path, resolution,
     with open(os.path.join(config_dir, 'downscaling.config'), 'w') as f:
         json.dump(config, f)

-    tmp_key2 = 'downscaled'
     halos = len(sample_factors) * [[0, 0, 0]]
-    # TODO this needs merge of
-    # https://github.com/constantinpape/cluster_tools/pull/17
     t = task(tmp_folder=tmp_folder, config_dir=config_dir,
              max_jobs=max_jobs, target=target,
              input_path=in_path, input_key=in_key,
              scale_factors=sample_factors, halos=halos,
-             metadata_format='paintera', metadata_dict={'resolution': resolution},
-             output_path=in_path, output_key_prefix=tmp_key2)
-    ret = luigi.build([t], local_scheduler=True)
-    if not ret:
-        raise RuntimeError("Downscaling failed")
-
-    task = PainteraToBdvWorkflow
-    config = task.get_config()['copy_volume']
-    config.update({'threads_per_job': 32, 'mem_limit': 64, 'time_limit': 1400})
-    with open(os.path.join(config_dir, 'copy_volume.config'), 'w') as f:
-        json.dump(config, f)
-
-    t = task(tmp_folder=tmp_folder, config_dir=config_dir,
-             max_jobs=1, target=target,
-             input_path=in_path, input_key_prefix=tmp_key2,
-             output_path=out_path,
-             metadata_dict={'resolution': resolution})
+             metadata_format='bdv', metadata_dict={'resolution': resolution,
+                                                   'unit': 'micrometer'},
+             output_path=out_path)

     ret = luigi.build([t], local_scheduler=True)
     if not ret:
         raise RuntimeError("Downscaling failed")
@@ -85,23 +77,23 @@ def downsample(ref_path, in_path, in_key, out_path, resolution,
 def intensity_correction(in_path, out_path, mask_path, mask_key,
                          trafo_path, tmp_folder, resolution,
-                         target='slurm', max_jobs=250, tmp_path=None):
+                         target='slurm', max_jobs=250):
     trafo_ext = os.path.splitext(trafo_path)[1]
     if trafo_ext == '.csv':
         trafo_path = csv_to_json(trafo_path)
     elif trafo_ext != '.json':
         raise ValueError("Expect trafo as json.")

-    key = 't00000/s00/0/cells'
+    in_is_h5 = os.path.splitext(in_path)[1].lower() in ('.h5', '.hdf5')
+    out_is_h5 = os.path.splitext(out_path)[1].lower() in ('.h5', '.hdf5')
+    key = get_key(in_is_h5, 0, 0, 0)
+    out_key = get_key(out_is_h5, 0, 0, 0)
     validate_trafo(trafo_path, in_path, key)

     config_dir = os.path.join(tmp_folder, 'configs')
     write_default_global_config(config_dir)

-    if tmp_path is None:
-        tmp_path = os.path.join(tmp_folder, 'data.n5')
-    tmp_key = 'data'
-
     task = LinearTransformationWorkflow
     conf = task.get_config()['linear']
     conf.update({'time_limit': 360, 'mem_limit': 8})
@@ -112,11 +104,11 @@ def intensity_correction(in_path, out_path, mask_path, mask_key,
              target=target, max_jobs=max_jobs,
              input_path=in_path, input_key=key,
              mask_path=mask_path, mask_key=mask_key,
-             output_path=tmp_path, output_key=tmp_key,
+             output_path=out_path, output_key=out_key,
              transformation=trafo_path)
     ret = luigi.build([t], local_scheduler=True)
     if not ret:
         raise RuntimeError("Transformation failed")

-    downsample(in_path, tmp_path, tmp_key, out_path, resolution,
+    downsample(in_path, out_path, out_key, out_path, resolution,
               tmp_folder, target, max_jobs)
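The core of the changes in this file is that the hard-coded 't00000/s00/...' HDF5 keys are replaced by pybdv.util.get_key, so the workflow can read and write both BDV-HDF5 and BDV-N5 containers. The following minimal sketch of the assumed key mapping is illustrative only and not part of the commit; the exact strings depend on the installed pybdv version.

# Illustrative sketch only, not part of the commit.
from pybdv.util import get_key

# BDV-HDF5 layout: expected to resolve to something like 't00000/s00/0/cells'
# (arguments: is_h5, timepoint, setup id, scale level, as used in the diff above)
print(get_key(True, 0, 0, 0))

# BDV-N5 layout: expected to resolve to something like 'setup0/timepoint0/s0'
print(get_key(False, 0, 0, 0))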