diff --git a/registration/apply_registration.py b/registration/apply_registration.py new file mode 100644 index 0000000000000000000000000000000000000000..992a323568b8438791cfdd7122b91366fc893f16 --- /dev/null +++ b/registration/apply_registration.py @@ -0,0 +1,75 @@ +#! /g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python +import argparse +import os +import json +import random +from shutil import rmtree + +import luigi +from scripts.extension.registration import ApplyRegistrationLocal, ApplyRegistrationSlurm + + +def apply_registration(input_path, output_path, transformation_file, + interpolation='nearest', output_format='tif', result_dtype='unsigned char', + target='local'): + task = ApplyRegistrationSlurm if target == 'slurm' else ApplyRegistrationLocal + assert result_dtype in task.result_types + assert interpolation in task.interpolation_modes + + rand_id = hash(random.uniform(0, 1000000)) + tmp_folder = 'tmp_%i' % rand_id + config_dir = os.path.join(tmp_folder, 'configs') + + os.makedirs(config_dir, exist_ok=True) + + shebang = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs/platybrowser/bin/python' + conf = task.default_global_config() + conf.update({'shebang': shebang}) + with open(os.path.join(config_dir, 'global.config'), 'w') as f: + json.dump(conf, f) + + task_config = task.default_task_config() + task_config.update({'mem_limit': 16, 'time_limit': 240, 'threads_per_job': 4, + 'ResultImagePixelType': result_dtype}) + with open(os.path.join(config_dir, 'apply_registration.config'), 'w') as f: + json.dump(task_config, f) + + in_file = os.path.join(tmp_folder, 'inputs.json') + with open(in_file, 'w') as f: + json.dump([input_path], f) + + out_file = os.path.join(tmp_folder, 'outputs.json') + with open(out_file, 'w') as f: + json.dump([output_path], f) + + t = task(tmp_folder=tmp_folder, config_dir=config_dir, max_jobs=1, + input_path_file=in_file, output_path_file=out_file, + transformation_file=transformation_file, output_format=output_format, + interpolation=interpolation) + ret = luigi.build([t], local_scheduler=True) + + if not ret: + raise RuntimeError("Apply registration failed") + + rmtree(tmp_folder) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Apply registration to input tif file') + parser.add_argument('input_path', type=str, + help="Path to input image volume. Must be tiff and have resolution information.") + parser.add_argument('output_path', type=str, help="Path to output.") + parser.add_argument('transformation_file', type=str, help="Path to transformation to apply.") + parser.add_argument('--interpolation', type=str, default='nearest', + help="Interpolation order that will be used. Can be 'nearest' or 'linear'.") + parser.add_argument('--output_format', type=str, default='tif', + help="Output file format. Can be 'tif' or 'xml'.") + parser.add_argument('--result_dtype', type=str, default='unsigned char', + help="Image datatype. Can be 'unsigned char' or 'unsigned short'.") + parser.add_argument('--target', type=str, default='local', + help="Where to run the computation. Can be 'local' or 'slurm'.") + + args = parser.parse_args() + apply_registration(args.input_path, args.output_path, args.transformation_file, + args.interpolation, args.output_format, args.result_dtype, + args.target) diff --git a/scripts/extension/registration/apply_registration.py b/scripts/extension/registration/apply_registration.py index 051bcab9dd2d8bb368242ec0015a582bf16cef62..8dd8b684a09ddaf61db095cf4c860f1092ff1832 100644 --- a/scripts/extension/registration/apply_registration.py +++ b/scripts/extension/registration/apply_registration.py @@ -3,6 +3,7 @@ import os import json import sys +from glob import glob from subprocess import check_output, CalledProcessError import luigi @@ -19,6 +20,11 @@ class ApplyRegistrationBase(luigi.Task): default_elastix = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/software/elastix_v4.8' formats = ('bdv', 'tif') + # what about cubic etc? + interpolation_modes = {'linear': 'FinalLinearInterpolator', + 'nearest': 'FinalNearestNeighborInterpolator'} + result_types = ('unsigned char', 'unsigned short') + task_name = 'apply_registration' src_file = os.path.abspath(__file__) allow_retry = False @@ -26,6 +32,7 @@ class ApplyRegistrationBase(luigi.Task): input_path_file = luigi.Parameter() output_path_file = luigi.Parameter() transformation_file = luigi.Parameter() + interpolation = luigi.Parameter(default='nearest') output_format = luigi.Parameter(default='bdv') fiji_executable = luigi.Parameter(default=default_fiji) elastix_directory = luigi.Parameter(default=default_elastix) @@ -34,6 +41,52 @@ class ApplyRegistrationBase(luigi.Task): def requires(self): return self.dependency + @staticmethod + def default_task_config(): + config = LocalTask.default_task_config() + config.update({'ResultImagePixelType': None}) + return config + + # update the transformation with our interpolation mode + # and the corresponding dtype + def update_transformation(self, in_file, out_file, res_type): + + interpolator_name = self.interpolation_modes[self.interpolation] + + def update_line(line, to_write): + line = line.rstrip('\n') + line = line.split() + line[1] = "\"%s\")" % to_write + line = " ".join(line) + "\n" + return line + + with open(in_file, 'r') as f_in, open(out_file, 'w') as f_out: + for line in f_in: + # change the interpolator + if line.startswith("(ResampleInterpolator"): + line = update_line(line, interpolator_name) + # change the pixel result type + elif line.startswith("(ResultImagePixelType") and res_type is not None: + line = update_line(line, res_type) + f_out.write(line) + + def update_transformations(self, res_type): + trafo_folder, trafo_name = os.path.split(self.transformation_file) + assert trafo_name.startswith('TransformParameters') + trafo_files = glob(os.path.join(trafo_folder, 'TransformParameters*')) + + out_folder = os.path.join(self.tmp_folder, 'transformations') + os.makedirs(out_folder, exist_ok=True) + + for trafo in trafo_files: + name = os.path.split(trafo)[1] + out = os.path.join(out_folder, name) + self.update_transformation(trafo, out, res_type) + + new_trafo = os.path.join(out_folder, trafo_name) + assert os.path.exists(new_trafo) + return new_trafo + def run_impl(self): # get the global config and init configs shebang = self.global_config_values()[0] @@ -51,18 +104,27 @@ class ApplyRegistrationBase(luigi.Task): assert os.path.exists(self.transformation_file) assert os.path.exists(self.fiji_executable) assert os.path.exists(self.elastix_directory) - assert self.output_format in (self.formats) + assert self.output_format in self.formats + assert self.interpolation in self.interpolation_modes + + config = self.get_task_config() + res_type = config.pop('res_type', None) + # TODO what are valid res types? + if res_type is not None: + assert res_type in self.result_types + trafo_file = self.update_transformations(res_type) # get the split of file-ids to the volume file_list = vu.blocks_in_volume((n_files,), (1,)) # we don't need any additional config besides the paths - config = {"input_path_file": self.input_path_file, - "output_path_file": self.output_path_file, - "transformation_file": self.transformation_file, - "fiji_executable": self.fiji_executable, - "elastix_directory": self.elastix_directory, - "tmp_folder": self.tmp_folder, 'output_format': self.output_format} + config.update({"input_path_file": self.input_path_file, + "output_path_file": self.output_path_file, + "transformation_file": trafo_file, + "fiji_executable": self.fiji_executable, + "elastix_directory": self.elastix_directory, + "tmp_folder": self.tmp_folder, + "output_format": self.output_format}) # prime and run the jobs n_jobs = min(self.max_jobs, n_files) @@ -108,7 +170,6 @@ def apply_for_file(input_path, output_path, assert os.path.exists(tmp_folder) assert os.path.exists(input_path) assert os.path.exists(transformation_file) - assert os.path.exists(os.path.split(output_path)[0]) if output_format == 'tif': format_str = 'Save as Tiff' @@ -117,14 +178,15 @@ def apply_for_file(input_path, output_path, else: assert False, "Invalid output format %s" % output_format + trafo_dir, trafo_name = os.path.split(transformation_file) # transformix arguments need to be passed as one string, # with individual arguments comma separated # the argument to transformaix needs to be one large comma separated string transformix_argument = ["elastixDirectory=\'%s\'" % elastix_directory, "workingDirectory=\'%s\'" % os.path.abspath(tmp_folder), - "inputImageFile=\'%s\'" % input_path, - "transformationFile=\'%s\'" % transformation_file, - "outputFile=\'%s\'" % output_path, + "inputImageFile=\'%s\'" % os.path.abspath(input_path), + "transformationFile=\'%s\'" % trafo_name, + "outputFile=\'%s\'" % os.path.abspath(output_path), "outputModality=\'%s\'" % format_str, "numThreads=\'%i\'" % n_threads] transformix_argument = ",".join(transformix_argument) @@ -164,6 +226,15 @@ def apply_for_file(input_path, output_path, fu.log("Go back to cwd: %s" % cwd) os.chdir(cwd) + if output_format == 'tif': + expected_output = output_path + '-ch0.tif' + elif output_format == 'bdv': + expected_output = output_path + '.xml' + + # the elastix plugin has the nasty habit of failing without throwing a proper error code. + # so we check here that we actually have the expected output. if not, something went wrong. + assert os.path.exists(expected_output), "The output %s is not there." % expected_output + def apply_registration(job_id, config_path): fu.log("start processing job %i" % job_id) diff --git a/test/registration/check_wrapper.py b/test/registration/check_wrapper.py index 5a381c33185cdc9838bffae8e9d949b6f53a4856..0f2eb029d9b0d08410e1341edfbc7ea2eb2ecf95 100644 --- a/test/registration/check_wrapper.py +++ b/test/registration/check_wrapper.py @@ -38,8 +38,10 @@ def check_wrapper(): # For now, we use the similarity trafo to save time trafo = os.path.join(trafo_dir, 'TransformParameters.Similarity-3Channels.0.txt') + interpolation = 'nearest' t = task(tmp_folder=tmp_folder, config_dir=conf_dir, max_jobs=1, - input_path_file=in_file, output_path_file=out_file, transformation_file=trafo) + input_path_file=in_file, output_path_file=out_file, transformation_file=trafo, + interpolation=interpolation) ret = luigi.build([t], local_scheduler=True) assert ret expected_xml = out_path + '.xml' diff --git a/update_registration.py b/update_registration.py index 4d779667dbfc7878ae17ef06f790693ee285f46f..3cd0870d192a3c63eb7db04b5a4d6fc8e9d19942 100755 --- a/update_registration.py +++ b/update_registration.py @@ -127,9 +127,13 @@ def apply_registration(input_folder, new_folder, with open(output_file, 'w') as f: json.dump(outputs, f) + # once we have other sources that are registered to the em space, + # we should expose the interpolation mode + interpolation = 'nearest' t = task(tmp_folder=tmp_folder, config_dir=config_dir, max_jobs=max_jobs, input_path_file=input_file, output_path_file=output_file, - transformation_file=transformation_file, output_format='tif') + transformation_file=transformation_file, output_format='tif', + interpolation=interpolation) ret = luigi.build([t], local_scheduler=True) if not ret: raise RuntimeError("Registration failed")