Skip to content
Snippets Groups Projects
Commit 261ccb84 authored by Constantin Pape's avatar Constantin Pape
Browse files

Enable setting registration mode in apply_registration

parent 9ae7a62d
No related branches found
No related tags found
No related merge requests found
#! /g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python
import argparse
import os
import json
import random
from shutil import rmtree
import luigi
from scripts.extension.registration import ApplyRegistrationLocal, ApplyRegistrationSlurm
def apply_registration(input_path, output_path, transformation_file,
interpolation='nearest', output_format='tif', result_dtype='unsigned char',
target='local'):
task = ApplyRegistrationSlurm if target == 'slurm' else ApplyRegistrationLocal
assert result_dtype in task.result_types
assert interpolation in task.interpolation_modes
rand_id = hash(random.uniform(0, 1000000))
tmp_folder = 'tmp_%i' % rand_id
config_dir = os.path.join(tmp_folder, 'configs')
os.makedirs(config_dir, exist_ok=True)
shebang = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python'
conf = task.default_global_config()
conf.update({'shebang': shebang})
with open(os.path.join(config_dir, 'global.config'), 'w') as f:
json.dump(conf, f)
task_config = task.default_task_config()
task_config.update({'mem_limit': 16, 'time_limit': 240, 'threads_per_job': 4,
'ResultImagePixelType': result_dtype})
with open(os.path.join(config_dir, 'apply_registration.config'), 'w') as f:
json.dump(task_config, f)
in_file = os.path.join(tmp_folder, 'inputs.json')
with open(in_file, 'w') as f:
json.dump([input_path], f)
out_file = os.path.join(tmp_folder, 'outputs.json')
with open(out_file, 'w') as f:
json.dump([output_path], f)
t = task(tmp_folder=tmp_folder, config_dir=config_dir, max_jobs=1,
input_path_file=in_file, output_path_file=out_file,
transformation_file=transformation_file, output_format=output_format,
interpolation=interpolation)
ret = luigi.build([t], local_scheduler=True)
if not ret:
raise RuntimeError("Apply registration failed")
rmtree(tmp_folder)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Apply registration to input tif file')
parser.add_argument('input_path', type=str,
help="Path to input image volume. Must be tiff and have resolution information.")
parser.add_argument('output_path', type=str, help="Path to output.")
parser.add_argument('transformation_file', type=str, help="Path to transformation to apply.")
parser.add_argument('--interpolation', type=str, default='nearest',
help="Interpolation order that will be used. Can be 'nearest' or 'linear'.")
parser.add_argument('--output_format', type=str, default='tif',
help="Output file format. Can be 'tif' or 'xml'.")
parser.add_argument('--result_dtype', type=str, default='unsigned char',
help="Image datatype. Can be 'unsigned char' or 'unsigned short'.")
parser.add_argument('--target', type=str, default='local',
help="Where to run the computation. Can be 'local' or 'slurm'.")
args = parser.parse_args()
apply_registration(args.input_path, args.output_path, args.transformation_file,
args.interpolation, args.output_format, args.result_dtype,
args.target)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import json import json
import sys import sys
from glob import glob
from subprocess import check_output, CalledProcessError from subprocess import check_output, CalledProcessError
import luigi import luigi
...@@ -19,6 +20,11 @@ class ApplyRegistrationBase(luigi.Task): ...@@ -19,6 +20,11 @@ class ApplyRegistrationBase(luigi.Task):
default_elastix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/software/elastix_v4.8' default_elastix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/software/elastix_v4.8'
formats = ('bdv', 'tif') formats = ('bdv', 'tif')
# what about cubic etc?
interpolation_modes = {'linear': 'FinalLinearInterpolator',
'nearest': 'FinalNearestNeighborInterpolator'}
result_types = ('unsigned char', 'unsigned short')
task_name = 'apply_registration' task_name = 'apply_registration'
src_file = os.path.abspath(__file__) src_file = os.path.abspath(__file__)
allow_retry = False allow_retry = False
...@@ -26,6 +32,7 @@ class ApplyRegistrationBase(luigi.Task): ...@@ -26,6 +32,7 @@ class ApplyRegistrationBase(luigi.Task):
input_path_file = luigi.Parameter() input_path_file = luigi.Parameter()
output_path_file = luigi.Parameter() output_path_file = luigi.Parameter()
transformation_file = luigi.Parameter() transformation_file = luigi.Parameter()
interpolation = luigi.Parameter(default='nearest')
output_format = luigi.Parameter(default='bdv') output_format = luigi.Parameter(default='bdv')
fiji_executable = luigi.Parameter(default=default_fiji) fiji_executable = luigi.Parameter(default=default_fiji)
elastix_directory = luigi.Parameter(default=default_elastix) elastix_directory = luigi.Parameter(default=default_elastix)
...@@ -34,6 +41,52 @@ class ApplyRegistrationBase(luigi.Task): ...@@ -34,6 +41,52 @@ class ApplyRegistrationBase(luigi.Task):
def requires(self): def requires(self):
return self.dependency return self.dependency
@staticmethod
def default_task_config():
config = LocalTask.default_task_config()
config.update({'ResultImagePixelType': None})
return config
# update the transformation with our interpolation mode
# and the corresponding dtype
def update_transformation(self, in_file, out_file, res_type):
interpolator_name = self.interpolation_modes[self.interpolation]
def update_line(line, to_write):
line = line.rstrip('\n')
line = line.split()
line[1] = "\"%s\")" % to_write
line = " ".join(line) + "\n"
return line
with open(in_file, 'r') as f_in, open(out_file, 'w') as f_out:
for line in f_in:
# change the interpolator
if line.startswith("(ResampleInterpolator"):
line = update_line(line, interpolator_name)
# change the pixel result type
elif line.startswith("(ResultImagePixelType") and res_type is not None:
line = update_line(line, res_type)
f_out.write(line)
def update_transformations(self, res_type):
trafo_folder, trafo_name = os.path.split(self.transformation_file)
assert trafo_name.startswith('TransformParameters')
trafo_files = glob(os.path.join(trafo_folder, 'TransformParameters*'))
out_folder = os.path.join(self.tmp_folder, 'transformations')
os.makedirs(out_folder, exist_ok=True)
for trafo in trafo_files:
name = os.path.split(trafo)[1]
out = os.path.join(out_folder, name)
self.update_transformation(trafo, out, res_type)
new_trafo = os.path.join(out_folder, trafo_name)
assert os.path.exists(new_trafo)
return new_trafo
def run_impl(self): def run_impl(self):
# get the global config and init configs # get the global config and init configs
shebang = self.global_config_values()[0] shebang = self.global_config_values()[0]
...@@ -51,18 +104,27 @@ class ApplyRegistrationBase(luigi.Task): ...@@ -51,18 +104,27 @@ class ApplyRegistrationBase(luigi.Task):
assert os.path.exists(self.transformation_file) assert os.path.exists(self.transformation_file)
assert os.path.exists(self.fiji_executable) assert os.path.exists(self.fiji_executable)
assert os.path.exists(self.elastix_directory) assert os.path.exists(self.elastix_directory)
assert self.output_format in (self.formats) assert self.output_format in self.formats
assert self.interpolation in self.interpolation_modes
config = self.get_task_config()
res_type = config.pop('res_type', None)
# TODO what are valid res types?
if res_type is not None:
assert res_type in self.result_types
trafo_file = self.update_transformations(res_type)
# get the split of file-ids to the volume # get the split of file-ids to the volume
file_list = vu.blocks_in_volume((n_files,), (1,)) file_list = vu.blocks_in_volume((n_files,), (1,))
# we don't need any additional config besides the paths # we don't need any additional config besides the paths
config = {"input_path_file": self.input_path_file, config.update({"input_path_file": self.input_path_file,
"output_path_file": self.output_path_file, "output_path_file": self.output_path_file,
"transformation_file": self.transformation_file, "transformation_file": trafo_file,
"fiji_executable": self.fiji_executable, "fiji_executable": self.fiji_executable,
"elastix_directory": self.elastix_directory, "elastix_directory": self.elastix_directory,
"tmp_folder": self.tmp_folder, 'output_format': self.output_format} "tmp_folder": self.tmp_folder,
"output_format": self.output_format})
# prime and run the jobs # prime and run the jobs
n_jobs = min(self.max_jobs, n_files) n_jobs = min(self.max_jobs, n_files)
...@@ -108,7 +170,6 @@ def apply_for_file(input_path, output_path, ...@@ -108,7 +170,6 @@ def apply_for_file(input_path, output_path,
assert os.path.exists(tmp_folder) assert os.path.exists(tmp_folder)
assert os.path.exists(input_path) assert os.path.exists(input_path)
assert os.path.exists(transformation_file) assert os.path.exists(transformation_file)
assert os.path.exists(os.path.split(output_path)[0])
if output_format == 'tif': if output_format == 'tif':
format_str = 'Save as Tiff' format_str = 'Save as Tiff'
...@@ -117,14 +178,15 @@ def apply_for_file(input_path, output_path, ...@@ -117,14 +178,15 @@ def apply_for_file(input_path, output_path,
else: else:
assert False, "Invalid output format %s" % output_format assert False, "Invalid output format %s" % output_format
trafo_dir, trafo_name = os.path.split(transformation_file)
# transformix arguments need to be passed as one string, # transformix arguments need to be passed as one string,
# with individual arguments comma separated # with individual arguments comma separated
# the argument to transformaix needs to be one large comma separated string # the argument to transformaix needs to be one large comma separated string
transformix_argument = ["elastixDirectory=\'%s\'" % elastix_directory, transformix_argument = ["elastixDirectory=\'%s\'" % elastix_directory,
"workingDirectory=\'%s\'" % os.path.abspath(tmp_folder), "workingDirectory=\'%s\'" % os.path.abspath(tmp_folder),
"inputImageFile=\'%s\'" % input_path, "inputImageFile=\'%s\'" % os.path.abspath(input_path),
"transformationFile=\'%s\'" % transformation_file, "transformationFile=\'%s\'" % trafo_name,
"outputFile=\'%s\'" % output_path, "outputFile=\'%s\'" % os.path.abspath(output_path),
"outputModality=\'%s\'" % format_str, "outputModality=\'%s\'" % format_str,
"numThreads=\'%i\'" % n_threads] "numThreads=\'%i\'" % n_threads]
transformix_argument = ",".join(transformix_argument) transformix_argument = ",".join(transformix_argument)
...@@ -164,6 +226,15 @@ def apply_for_file(input_path, output_path, ...@@ -164,6 +226,15 @@ def apply_for_file(input_path, output_path,
fu.log("Go back to cwd: %s" % cwd) fu.log("Go back to cwd: %s" % cwd)
os.chdir(cwd) os.chdir(cwd)
if output_format == 'tif':
expected_output = output_path + '-ch0.tif'
elif output_format == 'bdv':
expected_output = output_path + '.xml'
# the elastix plugin has the nasty habit of failing without throwing a proper error code.
# so we check here that we actually have the expected output. if not, something went wrong.
assert os.path.exists(expected_output), "The output %s is not there." % expected_output
def apply_registration(job_id, config_path): def apply_registration(job_id, config_path):
fu.log("start processing job %i" % job_id) fu.log("start processing job %i" % job_id)
......
...@@ -38,8 +38,10 @@ def check_wrapper(): ...@@ -38,8 +38,10 @@ def check_wrapper():
# For now, we use the similarity trafo to save time # For now, we use the similarity trafo to save time
trafo = os.path.join(trafo_dir, 'TransformParameters.Similarity-3Channels.0.txt') trafo = os.path.join(trafo_dir, 'TransformParameters.Similarity-3Channels.0.txt')
interpolation = 'nearest'
t = task(tmp_folder=tmp_folder, config_dir=conf_dir, max_jobs=1, t = task(tmp_folder=tmp_folder, config_dir=conf_dir, max_jobs=1,
input_path_file=in_file, output_path_file=out_file, transformation_file=trafo) input_path_file=in_file, output_path_file=out_file, transformation_file=trafo,
interpolation=interpolation)
ret = luigi.build([t], local_scheduler=True) ret = luigi.build([t], local_scheduler=True)
assert ret assert ret
expected_xml = out_path + '.xml' expected_xml = out_path + '.xml'
......
...@@ -127,9 +127,13 @@ def apply_registration(input_folder, new_folder, ...@@ -127,9 +127,13 @@ def apply_registration(input_folder, new_folder,
with open(output_file, 'w') as f: with open(output_file, 'w') as f:
json.dump(outputs, f) json.dump(outputs, f)
# once we have other sources that are registered to the em space,
# we should expose the interpolation mode
interpolation = 'nearest'
t = task(tmp_folder=tmp_folder, config_dir=config_dir, max_jobs=max_jobs, t = task(tmp_folder=tmp_folder, config_dir=config_dir, max_jobs=max_jobs,
input_path_file=input_file, output_path_file=output_file, input_path_file=input_file, output_path_file=output_file,
transformation_file=transformation_file, output_format='tif') transformation_file=transformation_file, output_format='tif',
interpolation=interpolation)
ret = luigi.build([t], local_scheduler=True) ret = luigi.build([t], local_scheduler=True)
if not ret: if not ret:
raise RuntimeError("Registration failed") raise RuntimeError("Registration failed")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment