From 80c1e89bfa5e068828b3ceb041ce10d89227e957 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Wed, 5 Jun 2019 21:07:56 +0200
Subject: [PATCH] Add script for object mapping tables

---
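Notes: a minimal usage sketch of the new mapping-table API added by this
patch. All paths, keys, the resolution and the tmp folder below are
placeholders, and it assumes the scripts package is importable:

    from scripts.attributes.base_attributes import base_attributes
    from scripts.attributes.map_objects import map_objects

    # base_attributes now also returns the label ids of the segmentation
    label_ids = base_attributes('seg.n5', 'volumes/cells', 'default.csv',
                                resolution=[0.025, 0.02, 0.02],
                                tmp_folder='./tmp', target='local',
                                max_jobs=8, correct_anchors=True)

    # write the table mapping each cell id to its overlapping nucleus id
    map_objects(label_ids, 'seg.n5', 'volumes/cells', 'objects.csv',
                map_paths=['nuclei.n5'], map_keys=['volumes/nuclei'],
                map_names=['nucleus_id'],
                tmp_folder='./tmp', target='local', max_jobs=8)
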
 scripts/attributes/base_attributes.py | 17 +++++----
 scripts/attributes/map_objects.py     | 51 +++++++++++++++++++++++++++
 scripts/attributes/master.py          | 27 ++++++++------
 scripts/attributes/util.py            | 10 ++++++
 scripts/export/export_segmentation.py |  6 ++--
 5 files changed, 88 insertions(+), 23 deletions(-)
 create mode 100644 scripts/attributes/map_objects.py
 create mode 100644 scripts/attributes/util.py

diff --git a/scripts/attributes/base_attributes.py b/scripts/attributes/base_attributes.py
index 5473f66..61c3e82 100644
--- a/scripts/attributes/base_attributes.py
+++ b/scripts/attributes/base_attributes.py
@@ -2,7 +2,6 @@
 # TODO new platy-browser env
 
 import os
-import csv
 import json
 import z5py
 import numpy as np
@@ -10,6 +9,7 @@ import numpy as np
 import luigi
 from cluster_tools.morphology import MorphologyWorkflow
 from cluster_tools.morphology import CorrectAnchorsWorkflow
+from .util import write_csv
 
 
 def make_config(tmp_folder):
@@ -71,7 +71,7 @@ def to_csv(input_path, input_key, output_path, resolution):
     label_ids = attributes[:, 0:1]
 
     # the column names
-    col_names = ['label_ids',
+    col_names = ['label_id',
                  'anchor_x', 'anchor_y', 'anchor_z',
                  'bb_min_x', 'bb_min_y', 'bb_min_z',
                  'bb_max_x', 'bb_max_y', 'bb_max_z',
@@ -98,13 +98,7 @@ def to_csv(input_path, input_key, output_path, resolution):
     # NOTE: attributes[1] = size in pixel
     # make the output attributes
     data = np.concatenate([label_ids, com, minc, maxc, attributes[:, 1:2]], axis=1)
-    assert data.shape[1] == len(col_names)
-
-    # write to csv
-    with open(output_path, 'w', newline='') as f:
-        writer = csv.writer(f, delimiter='\t')
-        writer.writerow(col_names)
-        writer.writerows(data)
+    write_csv(output_path, data, col_names)
 
 
 def base_attributes(input_path, input_key, output_path, resolution,
@@ -126,3 +120,8 @@ def base_attributes(input_path, input_key, output_path, resolution,
 
     # write output to csv
     to_csv(tmp_path, tmp_key, output_path, resolution)
+
+    # load and return label_ids
+    with z5py.File(tmp_path, 'r') as f:
+        label_ids = f[tmp_key][:, 0]
+    return label_ids
diff --git a/scripts/attributes/map_objects.py b/scripts/attributes/map_objects.py
new file mode 100644
index 0000000..b817eef
--- /dev/null
+++ b/scripts/attributes/map_objects.py
@@ -0,0 +1,51 @@
+#! /g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python
+# TODO new platy-browser env
+
+import os
+import numpy as np
+import luigi
+import z5py
+
+from cluster_tools.node_labels import NodeLabelWorkflow
+from .util import write_csv
+
+
+def object_labels(seg_path, seg_key,
+                  input_path, input_key, prefix,
+                  tmp_folder, target, max_jobs):
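+    """ Compute per-object overlaps via the cluster_tools NodeLabelWorkflow.
+
+    Returns an array that holds, for each id in the segmentation, the id of
+    the maximally overlapping object in the input volume.
+    """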
+    task = NodeLabelWorkflow
+    config_folder = os.path.join(tmp_folder, 'configs')
+
+    out_path = os.path.join(tmp_folder, 'data.n5')
+    out_key = 'node_labels_%s' % prefix
+
+    t = task(tmp_folder=tmp_folder, config_dir=config_folder,
+             max_jobs=max_jobs, target=target,
+             ws_path=seg_path, ws_key=seg_key,
+             input_path=input_path, input_key=input_key,
+             output_path=out_path, output_key=out_key,
+             prefix=prefix)
+    ret = luigi.build([t], local_scheduler=True)
+    if not ret:
+        raise RuntimeError("Node label computation for %s failed" % prefix)
+
+    with z5py.File(out_path, 'r') as f:
+        data = f[out_key][:]
+    return data
+
+
+def map_objects(label_ids, seg_path, seg_key, map_out,
+                map_paths, map_keys, map_names,
+                tmp_folder, target, max_jobs):
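+    """ Write a csv table that maps each segment id to the ids of the
+    overlapping objects in other segmentations, one column per mapping.
+    """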
+    assert len(map_paths) == len(map_keys) == len(map_names)
+
+    data = []
+    for map_path, map_key, prefix in zip(map_paths, map_keys, map_names):
+        labels = object_labels(seg_path, seg_key,
+                               map_path, map_key, prefix,
+                               tmp_folder, target, max_jobs)
+        data.append(labels[:, None])
+
+    col_names = ['label_id'] + map_names
+    # concatenate label ids and mapped object ids along the column axis
+    data = np.concatenate([label_ids[:, None]] + data, axis=1)
+    write_csv(map_out, data, col_names)
diff --git a/scripts/attributes/master.py b/scripts/attributes/master.py
index fa82a99..f02335a 100644
--- a/scripts/attributes/master.py
+++ b/scripts/attributes/master.py
@@ -2,6 +2,7 @@ import os
 import h5py
 
 from .base_attributes import base_attributes
+from .map_objects import map_objects
 from ..files import get_h5_path_from_xml
 
 
@@ -10,7 +11,6 @@ def get_seg_path(folder, name, key):
     path = get_h5_path_from_xml(xml_path, return_absolute_path=True)
     assert os.path.exists(path), path
     with h5py.File(path, 'r') as f:
-        print(path)
         assert key in f, "%s not in %s" % (key, str(list(f.keys())))
     return path
 
@@ -26,17 +26,23 @@ def make_cell_tables(folder, name, tmp_folder, resolution,
 
     # make the basic attributes table
     base_out = os.path.join(table_folder, 'default.csv')
-    base_attributes(seg_path, seg_key, base_out, resolution,
-                    tmp_folder, target=target, max_jobs=max_jobs,
-                    correct_anchors=True)
+    label_ids = base_attributes(seg_path, seg_key, base_out, resolution,
+                                tmp_folder, target=target, max_jobs=max_jobs,
+                                correct_anchors=True)
 
-    # TODO
     # make table with mapping to other objects
-    # nuclei
-    # cellular models
+    # nuclei, cellular models (TODO), ...
+    map_out = os.path.join(table_folder, 'objects.csv')
+    map_paths = [get_seg_path(folder, 'em-segmented-nuclei-labels', seg_key)]
+    map_keys = [seg_key]
+    map_names = ['nucleus_id']
+    map_objects(label_ids, seg_path, seg_key, map_out,
+                map_paths, map_keys, map_names,
+                tmp_folder, target, max_jobs)
 
     # TODO additional tables:
-    # gene mapping
+    # regions / semantics
+    # gene expression
     # ???
 
 
@@ -55,9 +61,10 @@ def make_nucleus_tables(folder, name, tmp_folder, resolution,
                     tmp_folder, target=target, max_jobs=max_jobs,
                     correct_anchors=True)
 
-    # TODO
+    # TODO do we need this for nuclei as well?
     # make table with mapping to other objects
-    # cells
+    # cells, ...
 
     # TODO additional tables:
     # kimberly's nucleus attributes
+    # ???
diff --git a/scripts/attributes/util.py b/scripts/attributes/util.py
new file mode 100644
index 0000000..ec26dc8
--- /dev/null
+++ b/scripts/attributes/util.py
@@ -0,0 +1,10 @@
+import csv
+
+
+def write_csv(output_path, data, col_names):
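+    """ Write a table with the given column names to a tab-separated csv file. """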
+    assert data.shape[1] == len(col_names), "%i %i" % (data.shape[1],
+                                                       len(col_names))
+    with open(output_path, 'w', newline='') as f:
+        writer = csv.writer(f, delimiter='\t')
+        writer.writerow(col_names)
+        writer.writerows(data)
diff --git a/scripts/export/export_segmentation.py b/scripts/export/export_segmentation.py
index e776ca1..6037980 100644
--- a/scripts/export/export_segmentation.py
+++ b/scripts/export/export_segmentation.py
@@ -53,7 +53,8 @@ def downscale(path, in_key, out_key,
         raise RuntimeError("Downscaling the segmentation failed")
 
 
-def export_segmentation(paintera_path, paintera_key, folder, new_folder, name, resolution, tmp_folder):
+def export_segmentation(paintera_path, paintera_key, folder, new_folder, name, resolution,
+                        tmp_folder, target='slurm', max_jobs=200):
     """ Export a segmentation from paintera project to bdv file and
     compute segment lut for previous segmentation.
 
@@ -66,9 +67,6 @@ def export_segmentation(paintera_path, paintera_key, folder, new_folder, name, r
         resolution: resolution [z, y, x] in micrometer
         tmp_folder: folder for temporary files
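+        target: computation target (default: 'slurm')
+        max_jobs: maximal number of jobs used for the computation (default: 200)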
     """
-    # TODO should make this a param
-    max_jobs = 250
-    target = 'slurm'
 
     tmp_path = os.path.join(tmp_folder, 'data.n5')
     tmp_key = 'seg'
-- 
GitLab