From a92a627c764f05c85bb01a9e60c337924b426c25 Mon Sep 17 00:00:00 2001 From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de> Date: Thu, 26 Sep 2019 15:10:25 +0200 Subject: [PATCH] Implement proper registration wrapper test and fix elastix config parser issue --- .../registration/apply_registration.py | 5 +- test/registration/check_wrapper.py | 54 --------- test/registration/test_wrapper.py | 107 ++++++++++++++++++ 3 files changed, 109 insertions(+), 57 deletions(-) delete mode 100644 test/registration/check_wrapper.py create mode 100644 test/registration/test_wrapper.py diff --git a/scripts/extension/registration/apply_registration.py b/scripts/extension/registration/apply_registration.py index 8dd8b68..6116b75 100644 --- a/scripts/extension/registration/apply_registration.py +++ b/scripts/extension/registration/apply_registration.py @@ -56,7 +56,7 @@ class ApplyRegistrationBase(luigi.Task): def update_line(line, to_write): line = line.rstrip('\n') line = line.split() - line[1] = "\"%s\")" % to_write + line = [line[0], "\"%s\")" % to_write] line = " ".join(line) + "\n" return line @@ -108,8 +108,7 @@ class ApplyRegistrationBase(luigi.Task): assert self.interpolation in self.interpolation_modes config = self.get_task_config() - res_type = config.pop('res_type', None) - # TODO what are valid res types? + res_type = config.pop('ResultImagePixelType', None) if res_type is not None: assert res_type in self.result_types trafo_file = self.update_transformations(res_type) diff --git a/test/registration/check_wrapper.py b/test/registration/check_wrapper.py deleted file mode 100644 index 0f2eb02..0000000 --- a/test/registration/check_wrapper.py +++ /dev/null @@ -1,54 +0,0 @@ -import os -import json - -import luigi -from scripts.extension.registration import ApplyRegistrationLocal - - -def check_wrapper(): - in_path = '/g/kreshuk/pape/Work/my_projects/platy-browser-data/registration/9.9.9/ProSPr/stomach.tif' - - tmp_folder = os.path.abspath('tmp_registration') - out_path = os.path.join(tmp_folder, 'stomach_prospr_registered') - - in_list = [in_path] - out_list = [out_path] - in_file = './in_list.json' - with open(in_file, 'w') as f: - json.dump(in_list, f) - out_file = './out_list.json' - with open(out_file, 'w') as f: - json.dump(out_list, f) - - task = ApplyRegistrationLocal - conf_dir = './configs' - os.makedirs(conf_dir, exist_ok=True) - - global_conf = task.default_global_config() - shebang = '/g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python' - global_conf.update({'shebang': shebang}) - with open(os.path.join(conf_dir, 'global.config'), 'w') as f: - json.dump(global_conf, f) - - trafo_dir = '/g/kreshuk/pape/Work/my_projects/platy-browser-data/registration/0.0.0/transformations' - - # This is the full transformation, but it takes a lot of time! - trafo = os.path.join(trafo_dir, 'TransformParameters.BSpline10-3Channels.0.txt') - - # For now, we use the similarity trafo to save time - trafo = os.path.join(trafo_dir, 'TransformParameters.Similarity-3Channels.0.txt') - - interpolation = 'nearest' - t = task(tmp_folder=tmp_folder, config_dir=conf_dir, max_jobs=1, - input_path_file=in_file, output_path_file=out_file, transformation_file=trafo, - interpolation=interpolation) - ret = luigi.build([t], local_scheduler=True) - assert ret - expected_xml = out_path + '.xml' - assert os.path.exists(expected_xml), expected_xml - expected_h5 = out_path + '.h5' - assert os.path.exists(expected_h5), expected_h5 - - -if __name__ == '__main__': - check_wrapper() diff --git a/test/registration/test_wrapper.py b/test/registration/test_wrapper.py new file mode 100644 index 0000000..3e958a5 --- /dev/null +++ b/test/registration/test_wrapper.py @@ -0,0 +1,107 @@ +import os +import json +import unittest + +import luigi +import numpy as np +import imageio +from shutil import rmtree +from scripts.extension.registration import ApplyRegistrationLocal + + +class TestRegistrationWrapper(unittest.TestCase): + tmp_folder = './tmp_regestration' + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) + + def tearDown(self): + try: + rmtree(self.tmp_folder) + except OSError: + pass + + def _apply_registration(self, in_path, trafo_file, interpolation, file_format, dtype): + out_path = os.path.join(self.tmp_folder, 'out') + + in_list = [in_path] + out_list = [out_path] + in_file = os.path.join(self.tmp_folder, 'in_list.json') + with open(in_file, 'w') as f: + json.dump(in_list, f) + out_file = os.path.join(self.tmp_folder, 'out_list.json') + with open(out_file, 'w') as f: + json.dump(out_list, f) + + task = ApplyRegistrationLocal + conf_dir = os.path.join(self.tmp_folder, 'configs') + os.makedirs(conf_dir, exist_ok=True) + + global_conf = task.default_global_config() + shebang = os.path.join('/g/arendt/EM_6dpf_segmentation/platy-browser-data/software/conda/miniconda3/envs', + 'platybrowser/bin/python') + global_conf.update({'shebang': shebang}) + with open(os.path.join(conf_dir, 'global.config'), 'w') as f: + json.dump(global_conf, f) + + task_conf = task.default_task_config() + task_conf.update({'threads_per_job': 8, 'ResultImagePixelType': dtype}) + with open(os.path.join(conf_dir, 'apply_registration.config'), 'w') as f: + json.dump(task_conf, f) + + t = task(tmp_folder=self.tmp_folder, config_dir=conf_dir, max_jobs=1, + input_path_file=in_file, output_path_file=out_file, transformation_file=trafo_file, + interpolation=interpolation, output_format=file_format) + ret = luigi.build([t], local_scheduler=True) + self.assertTrue(ret) + return out_path + + # only makes sense for nearest neighbor interpolation + def check_result(self, in_path, res_path, check_range=False): + res = imageio.volread(res_path) + exp = imageio.volread(in_path).astype(res.dtype) + + if check_range: + min_res = res.min() + min_exp = exp.min() + self.assertEqual(min_res, min_exp) + max_res = res.max() + max_exp = exp.max() + self.assertEqual(max_res, max_exp) + else: + un_res = np.unique(res) + un_exp = np.unique(exp) + self.assertTrue(np.array_equal(un_exp, un_res)) + + def test_nearest_mask(self): + trafo_dir = '/g/kreshuk/pape/Work/my_projects/platy-browser-data/registration/0.0.0/transformations' + # This is the full transformation, but it takes a lot of time! + # trafo_file = os.path.join(trafo_dir, 'TransformParameters.BSpline10-3Channels.0.txt') + # For now, we use the similarity trafo to save time + trafo_file = os.path.join(trafo_dir, 'TransformParameters.Similarity-3Channels.0.txt') + + in_path = '/g/kreshuk/pape/Work/my_projects/platy-browser-data/registration/9.9.9/ProSPr/stomach.tif' + out_path = self._apply_registration(in_path, trafo_file, 'nearest', 'tif', 'unsigned char') + + out_path = out_path + '-ch0.tif' + self.assertTrue(os.path.exists(out_path)) + self.check_result(in_path, out_path) + + def test_nearest_seg(self): + trafo_dir = '/g/kreshuk/pape/Work/my_projects/platy-browser-data/registration/0.0.0/transformations' + # This is the full transformation, but it takes a lot of time! + # trafo_file = os.path.join(trafo_dir, 'TransformParameters.BSpline10-3Channels.0.txt') + # For now, we use the similarity trafo to save time + trafo_file = os.path.join(trafo_dir, 'TransformParameters.Similarity-3Channels.0.txt') + + in_path = '/g/kreshuk/zinchenk/cell_match/data/genes/vc_volume_prospr_space_all_vc.tif' + out_path = self._apply_registration(in_path, trafo_file, 'nearest', 'tif', 'unsigned short') + + out_path = out_path + '-ch0.tif' + self.assertTrue(os.path.exists(out_path)) + # we can only check the range for segmentations, because individual ids might be lost + self.check_result(in_path, out_path, check_range=True) + + +if __name__ == '__main__': + unittest.main() -- GitLab