From 4816b25e441c33dce1782e9591ac2ee599416618 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Thu, 27 Jun 2019 18:23:58 +0200
Subject: [PATCH] Add segmented prospr regions to region attributes and adapt
 muscle mapping

---
 scripts/attributes/master.py            |  4 +-
 scripts/attributes/region_attributes.py | 72 ++++++++++++++++++++-----
 scripts/attributes/util.py              | 31 +++++++++--
 scripts/export/to_bdv.py                |  8 +--
 test/attributes/test_regions.py         |  8 +--
 5 files changed, 99 insertions(+), 24 deletions(-)

diff --git a/scripts/attributes/master.py b/scripts/attributes/master.py
index d25dffc..3c2d477 100644
--- a/scripts/attributes/master.py
+++ b/scripts/attributes/master.py
@@ -65,8 +65,10 @@ def make_cell_tables(folder, name, tmp_folder, resolution,
     region_out = os.path.join(table_folder, 'regions.csv')
     # need to make sure the inputs are copied / updated in
     # the segmentation folder beforehand
+    image_folder = os.path.join(folder, 'images')
     segmentation_folder = os.path.join(folder, 'segmentations')
-    region_attributes(seg_path, region_out, segmentation_folder,
+    region_attributes(seg_path, region_out,
+                      image_folder, segmentation_folder,
                       label_ids, tmp_folder, target, max_jobs)
 
 
diff --git a/scripts/attributes/region_attributes.py b/scripts/attributes/region_attributes.py
index 23e5c07..67f9635 100644
--- a/scripts/attributes/region_attributes.py
+++ b/scripts/attributes/region_attributes.py
@@ -1,8 +1,9 @@
 import os
+import glob
 import numpy as np
 import h5py
 
-from .util import write_csv, node_labels
+from .util import write_csv, node_labels, normalize_overlap_dict
 from ..files import get_h5_path_from_xml
 
 
@@ -35,18 +36,48 @@ def write_region_table(label_ids, label_list, semantic_mapping_list, out_path):
     write_csv(out_path, table, col_names)
 
 
-def region_attributes(seg_path, region_out, segmentation_folder,
+def muscle_attributes(muscle_path, key_muscle,
+                      seg_path, key_seg,
+                      tmp_folder, target, max_jobs):
+    muscle_labels = node_labels(seg_path, key_seg,
+                                muscle_path, key_muscle,
+                                'muscle', tmp_folder,
+                                target, max_jobs, max_overlap=False)
+
+    foreground_id = 255
+
+    # we count everything that has at least 25 % overlap as muscle
+    overlap_threshold = .25
+    muscle_labels = normalize_overlap_dict(muscle_labels)
+    label_ids = [k for k in sorted(muscle_labels.keys())]
+    muscle_labels = np.array([foreground_id if muscle_labels[label_id].get(foreground_id, 0) > overlap_threshold
+                              else 0 for label_id in label_ids])
+
+    # print()
+    # print()
+    # print(np.sum(muscle_labels == foreground_id))
+    # print(muscle_labels.size)
+    # print()
+    # print()
+
+    semantic_muscle = {'muscle': [foreground_id]}
+    return muscle_labels, semantic_muscle
+
+
+def region_attributes(seg_path, region_out,
+                      image_folder, segmentation_folder,
                       label_ids, tmp_folder, target, max_jobs):
 
-    key0 = 't00000/s00/0/cells'
+    key_seg = 't00000/s00/2/cells'
+    key_tissue = 't00000/s00/0/cells'
 
     # 1.) compute the mapping to carved regions
     #
     carved_path = os.path.join(segmentation_folder,
                                'sbem-6dpf-1-whole-segmented-tissue-labels.xml')
     carved_path = get_h5_path_from_xml(carved_path, return_absolute_path=True)
-    carved_labels = node_labels(seg_path, key0,
-                                carved_path, key0,
+    carved_labels = node_labels(seg_path, key_seg,
+                                carved_path, key_tissue,
                                 'carved-regions', tmp_folder,
                                 target, max_jobs)
     # load the mapping of ids to semantics
@@ -56,16 +87,31 @@ def region_attributes(seg_path, region_out, segmentation_folder,
     semantics_to_carved_ids = {name: idx.tolist()
                                for name, idx in zip(names, ids)}
 
+    label_list = [carved_labels]
+    semantic_mapping_list = [semantics_to_carved_ids]
+
     # 2.) compute the mapping to muscles
     muscle_path = os.path.join(segmentation_folder, 'sbem-6dpf-1-whole-segmented-muscles.xml')
     muscle_path = get_h5_path_from_xml(muscle_path, return_absolute_path=True)
-    muscle_labels = node_labels(seg_path, key0,
-                                muscle_path, key0,
-                                'muscle', tmp_folder,
-                                target, max_jobs)
-    semantic_muscle = {'muscle': [255]}
+    # need to be more lenient with the overlap criterion for the muscle mapping
+    muscle_labels, semantic_muscle = muscle_attributes(muscle_path, key_tissue,
+                                                       seg_path, key_seg,
+                                                       tmp_folder, target, max_jobs)
+    label_list.append(muscle_labels)
+    semantic_mapping_list.append(semantic_muscle)
+
+    # map all the segmented prospr regions
+    region_paths = glob.glob(os.path.join(image_folder, "prospr-6dpf-1-whole-segmented-*"))
+    region_names = [os.path.splitext(pp.split('-')[-1])[0].lower() for pp in region_paths]
+    region_paths = [get_h5_path_from_xml(rp, return_absolute_path=True)
+                    for rp in region_paths]
+    for rpath, rname in zip(region_paths, region_names):
+        rlabels = node_labels(seg_path, key_seg,
+                              rpath, key_tissue,
+                              rname, tmp_folder,
+                              target, max_jobs)
+        label_list.append(rlabels)
+        semantic_mapping_list.append({rname: [255]})
 
     # 3.) merge the mappings and write new table
-    write_region_table(label_ids, [carved_labels, muscle_labels],
-                       [semantics_to_carved_ids, semantic_muscle],
-                       region_out)
+    write_region_table(label_ids, label_list, semantic_mapping_list, region_out)
diff --git a/scripts/attributes/util.py b/scripts/attributes/util.py
index 2b9ae18..d1524b0 100644
--- a/scripts/attributes/util.py
+++ b/scripts/attributes/util.py
@@ -2,6 +2,7 @@ import os
 import csv
 import luigi
 import z5py
+import nifty.distributed as ndist
 
 from cluster_tools.node_labels import NodeLabelWorkflow
 
@@ -15,9 +16,19 @@ def write_csv(output_path, data, col_names):
         writer.writerows(data)
 
 
+def normalize_overlap_dict(label_overlap_dict):
+    sums = {label_id: sum(overlaps.values())
+            for label_id, overlaps in label_overlap_dict.items()}
+    label_overlap_dict = {label_id: {ovlp_id: count / sums[label_id]
+                                     for ovlp_id, count in overlaps.items()}
+                          for label_id, overlaps in label_overlap_dict.items()}
+    return label_overlap_dict
+
+
 def node_labels(seg_path, seg_key,
                 input_path, input_key, prefix,
-                tmp_folder, target, max_jobs):
+                tmp_folder, target, max_jobs,
+                max_overlap=True):
     task = NodeLabelWorkflow
     config_folder = os.path.join(tmp_folder, 'configs')
 
@@ -29,11 +40,23 @@ def node_labels(seg_path, seg_key,
              ws_path=seg_path, ws_key=seg_key,
              input_path=input_path, input_key=input_key,
              output_path=out_path, output_key=out_key,
-             prefix=prefix, )
+             prefix=prefix, max_overlap=max_overlap)
     ret = luigi.build([t], local_scheduler=True)
     if not ret:
         raise RuntimeError("Node labels for %s" % prefix)
 
-    with z5py.File(out_path, 'r') as f:
-        data = f[out_key][:]
+    f = z5py.File(out_path, 'r')
+    ds_out = f[out_key]
+
+    if max_overlap:
+        data = ds_out[:]
+    else:
+        n_chunks = ds_out.number_of_chunks
+        out_path = ds_out.path
+        data = [ndist.deserializeOverlapChunk(out_path, (chunk_id,))[0]
+                for chunk_id in range(n_chunks)]
+        data = {label_id: overlaps
+                for chunk_data in data
+                for label_id, overlaps in chunk_data.items()}
+
     return data
diff --git a/scripts/export/to_bdv.py b/scripts/export/to_bdv.py
index 3df04d5..9c98032 100644
--- a/scripts/export/to_bdv.py
+++ b/scripts/export/to_bdv.py
@@ -49,7 +49,9 @@ def to_bdv(in_path, in_key, out_path, resolution, tmp_folder, target='slurm'):
     if not ret:
         raise RuntimeError("Segmentation export failed")
 
-    # write the max-id
+    # write the max-id for all datasets
     with h5py.File(out_path) as f:
-        ds = f['t00000/s00/0/cells']
-        ds.attrs['maxId'] = max_id
+        g = f['t00000/s00']
+        for scale_group in g.items():
+            ds = scale_group['cells']
+            ds.attrs['maxId'] = max_id
diff --git a/test/attributes/test_regions.py b/test/attributes/test_regions.py
index 66209d6..7ac4378 100644
--- a/test/attributes/test_regions.py
+++ b/test/attributes/test_regions.py
@@ -15,7 +15,7 @@ sys.path.append('../..')
 class TestRegions(unittest.TestCase):
     tmp_folder = 'tmp'
 
-    def tearDown(self):
+    def _tearDown(self):
         try:
             rmtree(self.tmp_folder)
         except OSError:
@@ -61,6 +61,7 @@ class TestRegions(unittest.TestCase):
     def test_regions(self):
         from scripts.attributes.region_attributes import region_attributes
 
+        image_folder = '../../data/0.0.0/images'
         segmentation_folder = '../../data/0.0.0/segmentations'
         seg_path = os.path.join(segmentation_folder,
                                 'sbem-6dpf-1-whole-segmented-cells-labels.h5')
@@ -81,8 +82,9 @@ class TestRegions(unittest.TestCase):
             json.dump(conf, f)
 
         target = 'local'
-        max_jobs = 32
-        region_attributes(seg_path, output_path, segmentation_folder,
+        max_jobs = 8
+        region_attributes(seg_path, output_path,
+                          image_folder, segmentation_folder,
                           label_ids, self.tmp_folder, target, max_jobs)
 
         table = pandas.read_csv(output_path, sep='\t')
-- 
GitLab