From 1c8e755e31bf93701460b2a95e29881d9184a1d2 Mon Sep 17 00:00:00 2001 From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de> Date: Sun, 30 Jun 2019 13:12:04 +0200 Subject: [PATCH] Implement anchor correction --- scripts/attributes/base_attributes.py | 82 ++++++++++++++++++++------- test/attributes/test_base.py | 23 +++++--- 2 files changed, 77 insertions(+), 28 deletions(-) diff --git a/scripts/attributes/base_attributes.py b/scripts/attributes/base_attributes.py index 11cda4b..24e3c10 100644 --- a/scripts/attributes/base_attributes.py +++ b/scripts/attributes/base_attributes.py @@ -1,11 +1,12 @@ import os import json +import h5py import z5py import numpy as np import luigi from cluster_tools.morphology import MorphologyWorkflow -from cluster_tools.morphology import CorrectAnchorsWorkflow +from cluster_tools.morphology import RegionCentersWorkflow from .util import write_csv @@ -40,28 +41,55 @@ def n5_attributes(input_path, input_key, tmp_folder, target, max_jobs): return out_path, out_key -# correct anchor points that are not inside of objects: -# for each object, check if the anchor point is in the object -# if it is NOT, do: -# - for all chunks that overlap with the bounding box of the object: -# - check if object is in chunk -# - if it is: set anchor to eccentricity center of object in chunk -# - else: continue -def run_correction(input_path, input_key, attr_path, attr_key, +# set the anchor to region center (= maximum of boundary distance transform +# inside the object) instead of com +def run_correction(input_path, input_key, tmp_folder, target, max_jobs): - task = CorrectAnchorsWorkflow + task = RegionCentersWorkflow config_folder = os.path.join(tmp_folder, 'configs') + out_path = os.path.join(tmp_folder, 'data.n5') + out_key = 'region_centers' + + # we need to run this at a lower scale, as a heuristic, + # we take the first scale with all dimensions < 1750 pix + # (corresponds to scale 4 in sbem) + max_dim_size = 1750 + scale_key = input_key + with 
h5py.File(input_path, 'r') as f: + while True: + shape = f[scale_key].shape + if all(sh <= max_dim_size for sh in shape): + break + + scale = int(scale_key.split('/')[2]) + 1 + next_scale_key = 't00000/s00/%i/cells' % scale + if next_scale_key not in f: + break + scale_key = next_scale_key + + with h5py.File(input_path, 'r') as f: + shape1 = f[input_key].shape + shape2 = f[scale_key].shape + scale_factor = np.array([float(sh1) / sh2 for sh1, sh2 in zip(shape1, shape2)]) + t = task(tmp_folder=tmp_folder, config_dir=config_folder, max_jobs=max_jobs, target=target, - input_path=input_path, input_key=input_key, - morphology_path=attr_path, morphology_key=attr_key) + input_path=input_path, input_key=scale_key, + output_path=out_path, output_key=out_key, + ignore_label=0) ret = luigi.build([t], local_scheduler=True) if not ret: raise RuntimeError("Anchor correction failed") + with z5py.File(out_path, 'r') as f: + anchors = f[out_key][:] + anchors *= scale_factor + return anchors + -def to_csv(input_path, input_key, output_path, resolution): +def to_csv(input_path, input_key, output_path, resolution, + anchors=None): # load the attributes from n5 with z5py.File(input_path, 'r') as f: attributes = f[input_key][:] @@ -86,15 +114,27 @@ def to_csv(input_path, input_key, output_path, resolution): return coords # center of mass / anchor points - com = translate_coordinate_tuple(attributes[:, 2:5]) + com = attributes[:, 2:5] + if anchors is None: + anchors = translate_coordinate_tuple(com) + else: + assert len(anchors) == len(com) + assert anchors.shape[1] == 3 + + # some of the corrected anchors might not be present, + # so we merge them with the com here + invalid_anchors = np.isclose(anchors, 0.).all(axis=1) + anchors[invalid_anchors] = com[invalid_anchors] + anchors = translate_coordinate_tuple(anchors) + # attributes[5:8] = min coordinate of bounding box minc = translate_coordinate_tuple(attributes[:, 5:8]) - # attributes[8:] = min coordinate of bounding box + # 
attributes[8:11] = max coordinate of bounding box maxc = translate_coordinate_tuple(attributes[:, 8:11]) # NOTE: attributes[1] = size in pixel # make the output attributes - data = np.concatenate([label_ids, com, minc, maxc, attributes[:, 1:2]], axis=1) + data = np.concatenate([label_ids, anchors, minc, maxc, attributes[:, 1:2]], axis=1) write_csv(output_path, data, col_names) @@ -110,13 +150,13 @@ def base_attributes(input_path, input_key, output_path, resolution, # correct anchor positions if correct_anchors: - pass - # TODO need to test this first - # run_correction(input_path, input_key, tmp_path, tmp_key, - # tmp_folder, target, max_jobs) + anchors = run_correction(input_path, input_key, + tmp_folder, target, max_jobs) + else: + anchors = None # write output to csv - to_csv(tmp_path, tmp_key, output_path, resolution) + to_csv(tmp_path, tmp_key, output_path, resolution, anchors) # load and return label_ids with z5py.File(tmp_path, 'r') as f: diff --git a/test/attributes/test_base.py b/test/attributes/test_base.py index 7d3b4e1..53a077a 100644 --- a/test/attributes/test_base.py +++ b/test/attributes/test_base.py @@ -7,11 +7,11 @@ from shutil import rmtree sys.path.append('../..') -# check new version of gene mapping against original -class TestGeneAttributes(unittest.TestCase): +# check the basic / default attributes +class TestBaseAttributes(unittest.TestCase): tmp_folder = 'tmp' - def tearDown(self): + def _tearDown(self): try: rmtree(self.tmp_folder) except OSError: @@ -27,7 +27,7 @@ class TestGeneAttributes(unittest.TestCase): max_jobs = 32 resolution = [0.1, 0.08, 0.08] base_attributes(input_path, input_key, output_path, resolution, - self.tmp_folder, target, max_jobs, correct_anchors=False) + self.tmp_folder, target, max_jobs, correct_anchors=True) table = pandas.read_csv(output_path, sep='\t') print("Checking attributes ...") @@ -45,10 +45,19 @@ class TestGeneAttributes(unittest.TestCase): bb = tuple(slice(int(min_ / res) - 1, int(max_ / res) + 2) for
min_, max_, res in zip(bb_min, bb_max, resolution)) seg = ds[bb] + shape = seg.shape - # TODO check anchor once we have anchor correction - # anchor = [row.anchor_z, row.anchor_y, row.anchor_x] - # anchor = tuple(anch // res - b.start for anch, res, b in zip(anchor, resolution, bb)) + anchor = [row.anchor_z, row.anchor_y, row.anchor_x] + anchor = tuple(int(anch // res - b.start) + for anch, res, b in zip(anchor, resolution, bb)) + + # TODO this check still fails for a few ids. I am not sure if this is a systematic problem + # or just some rounding inaccuracies + # we need to give some tolerance here + anchor_bb = tuple(slice(max(0, an - 2), min(an + 2, sh)) for an, sh in zip(anchor, shape)) + sub = seg[anchor_bb] + print(label_id) + self.assertTrue(label_id in sub) # anchor_id = seg[anchor] # self.assertEqual(anchor_id, label_id) -- GitLab