From dd0aa99eae6075035974a96d95e98dc7442f2447 Mon Sep 17 00:00:00 2001
From: Constantin Pape <c.pape@gmx.net>
Date: Mon, 16 Sep 2019 21:42:28 +0200
Subject: [PATCH] Add python scripts to run registration on the cluster

---
 scripts/extension/registration/__init__.py    |   1 +
 .../registration/apply_registration.py        | 170 ++++++++++++++++++
 update_registration.py                        | 124 +++++++++++++
 3 files changed, 295 insertions(+)
 create mode 100644 scripts/extension/registration/__init__.py
 create mode 100644 scripts/extension/registration/apply_registration.py
 create mode 100755 update_registration.py

diff --git a/scripts/extension/registration/__init__.py b/scripts/extension/registration/__init__.py
new file mode 100644
index 0000000..b2cb180
--- /dev/null
+++ b/scripts/extension/registration/__init__.py
@@ -0,0 +1 @@
+from .apply_registration import ApplyRegistrationLocal, ApplyRegistrationSlurm
diff --git a/scripts/extension/registration/apply_registration.py b/scripts/extension/registration/apply_registration.py
new file mode 100644
index 0000000..435a45a
--- /dev/null
+++ b/scripts/extension/registration/apply_registration.py
@@ -0,0 +1,170 @@
+#! /usr/bin/python
+
+import os
+import json
+import sys
+from subprocess import check_output, CalledProcessError
+
+import luigi
+import cluster_tools.utils.function_utils as fu
+import cluster_tools.utils.volume_utils as vu
+from cluster_tools.task_utils import DummyTask
+from cluster_tools.cluster_tasks import SlurmTask, LocalTask, LSFTask
+
+
+class ApplyRegistrationBase(luigi.Task):
+    """ ApplyRegistration base class
+    """
+
+    task_name = 'apply_registration'
+    src_file = os.path.abspath(__file__)
+    allow_retry = False
+
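+    # input_path_file and output_path_file are json files holding the lists of
+    # input / output paths; transformation_file is the elastix transformation
+    # that is applied to all inputs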
+    input_path_file = luigi.Parameter()
+    output_path_file = luigi.Parameter()
+    transformation_file = luigi.Parameter()
+    fiji_executable = luigi.Parameter(default='/g/almf/software/Fiji.app/ImageJ-linux64')
+    elastix_directory = luigi.Parameter(default='/g/almf/software/elastix_v4.8')
+    dependency = luigi.TaskParameter(default=DummyTask())
+
+    def requires(self):
+        return self.dependency
+
+    def run_impl(self):
+        # get the global config and init configs
+        shebang = self.global_config_values()[0]
+        self.init(shebang)
+
+        with open(self.input_path_file) as f:
+            inputs = json.load(f)
+        with open(self.output_path_file) as f:
+            outputs = json.load(f)
+
+        assert len(inputs) == len(outputs)
+        assert all(os.path.exists(inp) for inp in inputs)
+        n_files = len(inputs)
+
+        assert os.path.exists(self.transformation_file)
+        assert os.path.exists(self.fiji_executable)
+        assert os.path.exists(self.elastix_directory)
+
+        # split the file ids into blocks of size 1, so that each file is one unit of work
+        file_list = vu.blocks_in_volume((n_files,), (1,))
+
+        # we don't need any additional config besides the paths
+        config = {"input_path_file": self.input_path_file,
+                  "output_path_file": self.output_path_file,
+                  "transformation_file": self.transformation_file,
+                  "fiji_executable": self.fiji_executable,
+                  "elastix_directory": self.elastix_directory,
+                  "tmp_folder": self.tmp_folder}
+
+        # prime and run the jobs
+        n_jobs = min(self.max_jobs, n_files)
+        self.prepare_jobs(n_jobs, file_list, config)
+        self.submit_jobs(n_jobs)
+
+        # wait till jobs finish and check for job success
+        self.wait_for_jobs()
+        self.check_jobs(n_jobs)
+
+
+class ApplyRegistrationLocal(ApplyRegistrationBase, LocalTask):
+    """
+    ApplyRegistration on local machine
+    """
+    pass
+
+
+class ApplyRegistrationSlurm(ApplyRegistrationBase, SlurmTask):
+    """
+    ApplyRegistration on slurm cluster
+    """
+    pass
+
+
+class ApplyRegistrationLSF(ApplyRegistrationBase, LSFTask):
+    """
+    ApplyRegistration on lsf cluster
+    """
+    pass
+
+
+#
+# Implementation
+#
+
+def apply_for_file(input_path, output_path,
+                   transformation_file, fiji_executable,
+                   elastix_directory, tmp_folder, n_threads):
+
+    # command based on https://github.com/embl-cba/fiji-plugin-elastixWrapper/issues/2:
+    # srun --mem 16000 -n 1 -N 1 -c 8 -t 30:00 -o $OUT -e $ERR
+    # /g/almf/software/Fiji.app/ImageJ-linux64  --ij2 --headless --run "Transformix"
+    # "elastixDirectory='/g/almf/software/elastix_v4.8', workingDirectory='$TMPDIR',
+    # inputImageFile='$INPUT_IMAGE',transformationFile='/g/cba/exchange/platy-trafos/linear/TransformParameters.BSpline10-3Channels.0.txt',
+    # outputFile='$OUTPUT_IMAGE',outputModality='Save as BigDataViewer .xml/.h5',numThreads='1'"
+    # the plugin arguments are passed to --run as a single string of
+    # comma-separated key='value' pairs, using the parameter names from
+    # the reference command above
+    plugin_args = ",".join(["elastixDirectory='%s'" % elastix_directory,
+                            "workingDirectory='%s'" % tmp_folder,
+                            "inputImageFile='%s'" % input_path,
+                            "transformationFile='%s'" % transformation_file,
+                            "outputFile='%s'" % output_path,
+                            "outputModality='Save as BigDataViewer .xml/.h5'",
+                            "numThreads='1'"])  # TODO why numThreads=1 and not the same as -c in the slurm command?
+    cmd = [fiji_executable, "--ij2", "--headless", "--run", "Transformix", plugin_args]
+
+    try:
+        check_output(cmd)
+    except CalledProcessError as e:
+        raise RuntimeError(e.output)
+
+
+def apply_registration(job_id, config_path):
+    fu.log("start processing job %i" % job_id)
+    fu.log("reading config from %s" % config_path)
+
+    # read the config
+    with open(config_path) as f:
+        config = json.load(f)
+
+    # get list of the input and output paths
+    input_file = config['input_path_file']
+    with open(input_file) as f:
+        inputs = json.load(f)
+    output_file = config['output_path_file']
+    with open(output_file) as f:
+        outputs = json.load(f)
+
+    transformation_file = config['transformation_file']
+    fiji_executable = config['fiji_executable']
+    elastix_directory = config['elastix_directory']
+    tmp_folder = config['tmp_folder']
+
+    file_list = config['block_list']
+    n_threads = config.get('threads_per_job', 1)
+
+    fu.log("Applying registration with:")
+    fu.log("transformation_file: %s" % transformation_file)
+    fu.log("fiji_executable: %s" % fiji_executable)
+    fu.log("elastix_directory: %s" % elastix_directory)
+
+    for file_id in file_list:
+        fu.log("start processing block %i" % file_id)
+
+        infile = inputs[file_id]
+        outfile = outputs[file_id]
+        fu.log("Input: %s" % infile)
+        fu.log("Output: %s" % outfile)
+        apply_for_file(infile, outfile,
+                       transformation_file, fiji_executable,
+                       elastix_directory, tmp_folder, n_threads)
+        fu.log_block_success(file_id)
+
+    fu.log_job_success(job_id)
+
+
+if __name__ == '__main__':
+    path = sys.argv[1]
+    assert os.path.exists(path), path
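+    # the job id is encoded as the trailing '_<id>' of the config file name,
+    # so parse it from the last underscore-separated token before the extension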
+    job_id = int(os.path.split(path)[1].split('.')[0].split('_')[-1])
+    apply_registration(job_id, path)
diff --git a/update_registration.py b/update_registration.py
new file mode 100755
index 0000000..c27a2da
--- /dev/null
+++ b/update_registration.py
@@ -0,0 +1,124 @@
+#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
+
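+# Usage example (the transformation file is the one referenced in
+# scripts/extension/registration/apply_registration.py):
+# ./update_registration.py /g/cba/exchange/platy-trafos/linear/TransformParameters.BSpline10-3Channels.0.txt
+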
+import os
+import json
+import argparse
+from subprocess import check_output
+import luigi
+
+from scripts.files import copy_release_folder, make_folder_structure, make_bdv_server_file
+from scripts.release_helper import add_version
+from scripts.extension.registration import ApplyRegistrationLocal, ApplyRegistrationSlurm
+from scripts.default_config import get_default_shebang
+
+
+def get_tags():
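+    # read the latest release tag and bump its last component,
+    # e.g. tag '0.6.1' -> new tag '0.6.2'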
+    tag = check_output(['git', 'describe', '--abbrev=0']).decode('utf-8').rstrip('\n')
+    new_tag = tag.split('.')
+    new_tag[-1] = str(int(new_tag[-1]) + 1)
+    new_tag = '.'.join(new_tag)
+    return tag, new_tag
+
+
+def apply_registration(input_folder, new_folder,
+                       transformation_file, source_prefix,
+                       target, max_jobs):
+    task = ApplyRegistrationSlurm if target == 'slurm' else ApplyRegistrationLocal
+    tmp_folder = './tmp_registration'
+    os.makedirs(tmp_folder, exist_ok=True)
+
+    # find all input files that match the source prefix
+    names = [name for name in os.listdir(input_folder)
+             if name.startswith(source_prefix)]
+    inputs = [os.path.join(input_folder, name) for name in names]
+
+    if len(inputs) == 0:
+        raise RuntimeError("Did not find any files with prefix %s in %s" % (source_prefix,
+                                                                            input_folder))
+
+    output_folder = os.path.join(new_folder, 'images')
+    # TODO parse names to get output names
+    output_names = []
+    outputs = [os.path.join(output_folder, name) for name in output_names]
+
+    # update the task config
+    config_dir = os.path.join(tmp_folder, 'config')
+    os.makedirs(config_dir, exist_ok=True)
+
+    shebang = get_default_shebang()
+    global_config = task.default_global_config()
+    global_config.update({'shebang': shebang})
+    with open(os.path.join(config_dir, 'global.config'), 'w') as f:
+        json.dump(global_config, f)
+
+    # TODO more time than 3 hrs?
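+    # request 16 GB of memory and 3 hours runtime per job
+    # (assuming mem_limit is given in GB and time_limit in minutes)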
+    task_config = task.default_task_config()
+    task_config.update({'mem_limit': 16, 'time_limit': 180})
+    with open(os.path.join(config_dir, 'apply_registration.config'), 'w') as f:
+        json.dump(task_config, f)
+
+    # write path name files to json
+    input_file = os.path.join(tmp_folder, 'input_files.json')
+    with open(input_file, 'w') as f:
+        json.dump(inputs, f)
+    output_file = os.path.join(tmp_folder, 'output_files.json')
+    with open(output_file, 'w') as f:
+        json.dump(outputs, f)
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_dir, max_jobs=max_jobs,
+             input_path_file=input_file, output_path_file=output_file,
+             transformation_file=transformation_file)
+
+    ret = luigi.build([t], local_scheduler=True)
+    if not ret:
+        raise RuntimeError("Registration failed")
+
+
+def update_registration(transformation_file, input_folder, source_prefix, target, max_jobs):
+    """ Update the prospr registration.
+    This is a special case of 'update_patch' that applies a new prospr registration.
+
+    Arguments:
+        transformation_file [str] - path to the transformation used to register
+        input_folder [str] - folder with unregistered data
+        source_prefix [str] - prefix of the source data to apply the registration to
+        target [str] - target of computation
+        max_jobs [int] - max number of jobs for computation
+    """
+    tag, new_tag = get_tags()
+    print("Updating platy browser from", tag, "to", new_tag)
+
+    # make new folder structure
+    folder = os.path.join('data', tag)
+    new_folder = os.path.join('data', new_tag)
+    make_folder_structure(new_folder)
+
+    # copy the release folder
+    copy_release_folder(folder, new_folder)
+
+    # apply new registration to all files of the source prefix
+    apply_registration(input_folder, new_folder,
+                       transformation_file, source_prefix,
+                       target, max_jobs)
+    add_version(new_tag)
+    make_bdv_server_file(new_folder, os.path.join(new_folder, 'misc', 'bdv_server.txt'),
+                         relative_paths=True)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Update prospr registration in platy-browser-data.')
+    parser.add_argument('transformation_file', type=str, help="path to transformation file")
+
+    parser.add_argument('--input_folder', type=str, default="data/rawdata/prospr",
+                        help="Folder with the unregistered input files")
+    help_str = "Prefix for the input data. Please change this if you change the 'input_folder' from its default value"
+    parser.add_argument('--source_prefix', type=str, default="prospr-6dpf-1-whole",
+                        help=help_str)
+
+    parser.add_argument('--target', type=str, default='slurm',
+                        help="Computation platform, can be 'slurm' or 'local'")
+    parser.add_argument('--max_jobs', type=int, default=100,
+                        help="Maximal number of jobs used for computation")
+
+    args = parser.parse_args()
+    update_registration(args.transformation_file, args.input_folder, args.source_prefix,
+                        args.target, args.max_jobs)
-- 
GitLab