From 0d49ef324f4403a0221504749dc89d8e53609549 Mon Sep 17 00:00:00 2001
From: Constantin Pape <c.pape@gmx.net>
Date: Fri, 4 Oct 2019 15:46:45 +0200
Subject: [PATCH] Add vc assignment implementation and update gene overlap
 assignment scripts

---
 scripts/attributes/genes.py                   |  25 ++-
 scripts/attributes/master.py                  |  16 +-
 scripts/extension/attributes/__init__.py      |   1 +
 scripts/extension/attributes/genes.py         |  89 +--------
 scripts/extension/attributes/genes_impl.py    | 108 +++++++++++
 .../extension/attributes/vc_assignments.py    | 107 +++++++++++
 .../attributes/vc_assignments_impl.py         | 181 ++++++++++++++++++
 update_registration.py                        |   7 +-
 8 files changed, 440 insertions(+), 94 deletions(-)
 create mode 100644 scripts/extension/attributes/genes_impl.py
 create mode 100644 scripts/extension/attributes/vc_assignments.py
 create mode 100644 scripts/extension/attributes/vc_assignments_impl.py

diff --git a/scripts/attributes/genes.py b/scripts/attributes/genes.py
index 2f007ce..fba1551 100755
--- a/scripts/attributes/genes.py
+++ b/scripts/attributes/genes.py
@@ -8,10 +8,11 @@ import h5py
 from tqdm import tqdm
 
 from ..extension.attributes import GenesLocal, GenesSlurm
+from ..extension.attributes import VCAssignmentsLocal, VCAssignmentsSlurm
 
 
-def write_genes_table(segm_file, genes_file, table_file, labels,
-                      tmp_folder, target, n_threads=8):
+def gene_assignment_table(segm_file, genes_file, table_file, labels,
+                          tmp_folder, target, n_threads=8):
     task = GenesSlurm if target == 'slurm' else GenesLocal
     seg_dset = 't00000/s00/4/cells'
 
@@ -35,6 +36,26 @@ def write_genes_table(segm_file, genes_file, table_file, labels,
         raise RuntimeError("Computing gene expressions failed")
 
 
+def vc_assignment_table(seg_path, vc_vol_path, vc_expression_path,
+                        med_expression_path, output_path,
+                        tmp_folder, target, n_threads=8):
+    """Run the VCAssignments task; raises RuntimeError if the luigi build fails."""
+    task = VCAssignmentsSlurm if target == 'slurm' else VCAssignmentsLocal
+    config_folder = os.path.join(tmp_folder, 'configs')
+    os.makedirs(config_folder, exist_ok=True)
+    config = task.default_task_config()
+    # this is very ram hungry because we load all the genes at once
+    config.update({'threads_per_job': n_threads, 'mem_limit': 256})
+    with open(os.path.join(config_folder, 'vc_assignments.config'), 'w') as f:
+        json.dump(config, f)
+    t = task(tmp_folder=tmp_folder, config_dir=config_folder, max_jobs=1,
+             segmentation_path=seg_path, vc_volume_path=vc_vol_path,
+             vc_expression_path=vc_expression_path,
+             med_expression_path=med_expression_path, output_path=output_path)
+    if not luigi.build([t], local_scheduler=True):
+        raise RuntimeError("Computing vc assignments failed")
+
+
 def find_nth(string, substring, n):
     if (n == 1):
         return string.find(substring)
diff --git a/scripts/attributes/master.py b/scripts/attributes/master.py
index 105404f..4b9aaf2 100644
--- a/scripts/attributes/master.py
+++ b/scripts/attributes/master.py
@@ -3,7 +3,7 @@ import h5py
 
 from .base_attributes import base_attributes, propagate_attributes
 from .cell_nucleus_mapping import map_cells_to_nuclei
-from .genes import write_genes_table
+from .genes import gene_assignment_table, vc_assignment_table
 from .morphology import write_morphology_cells, write_morphology_nuclei
 from .region_attributes import region_attributes
 from .cilia_attributes import cilia_morphology
@@ -46,8 +46,18 @@ def make_cell_tables(old_folder, folder, name, tmp_folder, resolution,
     if not os.path.exists(aux_gene_path):
         raise RuntimeError("Can't find auxiliary gene file @ %s" % aux_gene_path)
     gene_out = os.path.join(table_folder, 'genes.csv')
-    write_genes_table(seg_path, aux_gene_path, gene_out, label_ids,
-                      tmp_folder, target)
+    gene_assignment_table(seg_path, aux_gene_path, gene_out, label_ids,
+                          tmp_folder, target)
+
+    # make table with gene mapping via VCs
+    vc_vol_path = os.path.join('segmentations', 'prospr-6dpf-1-whole-virtual-cells-labels.xml')
+    vc_vol_path = get_h5_path_from_xml(vc_vol_path, return_absolute_path=True)
+    vc_expression_path = os.path.join('tables', 'prospr-6dpf-1-whole-virtual-cells-labels', 'profile_clust_curated.csv')
+    med_expression_path = gene_out
+    vc_out = os.path.join(table_folder, 'vc_assignments.csv')
+    vc_assignment_table(seg_path, vc_vol_path, vc_expression_path,
+                        med_expression_path, vc_out,
+                        tmp_folder, target)
 
     # make table with morphology
     morpho_out = os.path.join(table_folder, 'morphology.csv')
diff --git a/scripts/extension/attributes/__init__.py b/scripts/extension/attributes/__init__.py
index 4a52cda..c885d43 100644
--- a/scripts/extension/attributes/__init__.py
+++ b/scripts/extension/attributes/__init__.py
@@ -1,2 +1,3 @@
 from .genes import GenesLocal, GenesSlurm
+from .vc_assignments import VCAssignmentsLocal, VCAssignmentsSlurm
 from .workflow import MorphologyWorkflow
diff --git a/scripts/extension/attributes/genes.py b/scripts/extension/attributes/genes.py
index 4424968..e5683cc 100644
--- a/scripts/extension/attributes/genes.py
+++ b/scripts/extension/attributes/genes.py
@@ -3,18 +3,14 @@
 import os
 import sys
 import json
-import csv
-from concurrent import futures
 
 import luigi
 import numpy as np
-from vigra.analysis import extractRegionFeatures
-from vigra.sampling import resize
 
-import cluster_tools.utils.volume_utils as vu
 import cluster_tools.utils.function_utils as fu
 from cluster_tools.utils.task_utils import DummyTask
 from cluster_tools.cluster_tasks import SlurmTask, LocalTask
+from .genes_impl import gene_assignments
 
 #
 # Gene Attribute Tasks
@@ -82,66 +78,6 @@ class GenesSlurm(GenesBase, SlurmTask):
 # Implementation
 #
 
-
-def get_sizes_and_bbs(data):
-    # compute the relevant vigra region features
-    features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'),
-                                     features=['Coord<Maximum >', 'Coord<Minimum >', 'Count'])
-
-    # extract sizes from features
-    cell_sizes = features['Count'].squeeze().astype('uint64')
-
-    # compute bounding boxes from features
-    mins = features['Coord<Minimum >'].astype('uint32')
-    maxs = features['Coord<Maximum >'].astype('uint32') + 1
-    cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_))
-                for min_, max_ in zip(mins, maxs)]
-    return cell_sizes, cell_bbs
-
-
-def get_cell_expression(segmentation, all_genes, n_threads):
-    num_genes = all_genes.shape[0]
-    # NOTE we need to recalculate the unique labels here, beacause we might not
-    # have all labels due to donwsampling
-    labels = np.unique(segmentation)
-    cells_expression = np.zeros((len(labels), num_genes), dtype='float32')
-    cell_sizes, cell_bbs = get_sizes_and_bbs(segmentation)
-
-    def compute_expressions(cell_idx, cell_label):
-        # get size and boundinng box of this cell
-        cell_size = cell_sizes[cell_label]
-        bb = cell_bbs[cell_label]
-        # get the cell mask and the gene expression in bounding box
-        cell_masked = segmentation[bb] == cell_label
-        genes_in_cell = all_genes[(slice(None),) + bb]
-        # accumulate the gene expression channels over the cell mask
-        gene_expr_sum = np.sum(genes_in_cell[:, cell_masked] > 0, axis=1)
-        # divide by the cell size and write result
-        cells_expression[cell_idx] = gene_expr_sum / cell_size
-
-    with futures.ThreadPoolExecutor(n_threads) as tp:
-        tasks = [tp.submit(compute_expressions, cell_idx, cell_label)
-                 for cell_idx, cell_label in enumerate(labels) if cell_label != 0]
-        [t.result() for t in tasks]
-    return labels, cells_expression
-
-
-def write_genes_table(output_path, expression, gene_names, labels, avail_labels):
-    n_labels = len(labels)
-    n_cols = len(gene_names) + 1
-
-    data = np.zeros((n_labels, n_cols), dtype='float32')
-    data[:, 0] = labels
-    data[avail_labels, 1:] = expression
-
-    col_names = ['label_id'] + gene_names
-    assert data.shape[1] == len(col_names)
-    with open(output_path, 'w') as f:
-        csv_writer = csv.writer(f, delimiter='\t')
-        csv_writer.writerow(col_names)
-        csv_writer.writerows(data)
-
-
 def genes(job_id, config_path):
 
     fu.log("start processing job %i" % job_id)
@@ -157,28 +93,9 @@ def genes(job_id, config_path):
     output_path = config['output_path']
     n_threads = config.get('threads_per_job', 1)
 
-    fu.log("Loading segmentation, labels and gene-data")
-    # load segmentation, labels and genes
-    with vu.file_reader(segmentation_path, 'r') as f:
-        segmentation = f[segmentation_key][:]
     labels = np.load(labels_path)
-
-    genes_dset = 'genes'
-    names_dset = 'gene_names'
-    with vu.file_reader(genes_path, 'r') as f:
-        ds = f[genes_dset]
-        gene_shape = ds.shape[1:]
-        all_genes = ds[:]
-        gene_names = [i.decode('utf-8') for i in f[names_dset]]
-
-    # resize the segmentation to gene space
-    segmentation = resize(segmentation.astype("float32"),
-                          shape=gene_shape, order=0).astype('uint16')
-    fu.log("Compute gene expression")
-    avail_labels, expression = get_cell_expression(segmentation, all_genes, n_threads)
-
-    fu.log('Save results to %s' % output_path)
-    write_genes_table(output_path, expression, gene_names, labels, avail_labels)
+    gene_assignments(segmentation_path, segmentation_key,
+                     genes_path, labels, output_path, n_threads)
     fu.log_job_success(job_id)
 
 
diff --git a/scripts/extension/attributes/genes_impl.py b/scripts/extension/attributes/genes_impl.py
new file mode 100644
index 0000000..8ca2704
--- /dev/null
+++ b/scripts/extension/attributes/genes_impl.py
@@ -0,0 +1,108 @@
+# implementation for the tasks in 'genes.py', can be called standalone
+import csv
+from concurrent import futures
+
+import h5py
+import numpy as np
+from vigra.analysis import extractRegionFeatures
+from vigra.sampling import resize
+
+
+def get_sizes_and_bbs(data):
+    # per-label region features from vigra: coordinate extrema and voxel count
+    features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'),
+                                     features=['Coord<Maximum >', 'Coord<Minimum >', 'Count'])
+
+    # voxel count per label
+    cell_sizes = features['Count'].squeeze().astype('uint64')
+
+    # bounding box per label as a tuple of slices; the +1 makes the max coordinate an exclusive stop
+    mins = features['Coord<Minimum >'].astype('uint32')
+    maxs = features['Coord<Maximum >'].astype('uint32') + 1
+    cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_))
+                for min_, max_ in zip(mins, maxs)]
+    return cell_sizes, cell_bbs
+
+
+def get_cell_expression(segmentation, all_genes, n_threads):
+    num_genes = all_genes.shape[0]
+    # NOTE we need to recalculate the unique labels here, because we might not
+    # have all labels due to downsampling
+    labels = np.unique(segmentation)
+    cells_expression = np.zeros((len(labels), num_genes), dtype='float32')
+    cell_sizes, cell_bbs = get_sizes_and_bbs(segmentation)
+
+    def compute_expressions(cell_idx, cell_label):
+        # get size and bounding box of this cell
+        cell_size = cell_sizes[cell_label]
+        bb = cell_bbs[cell_label]
+        # get the cell mask and the gene expression in bounding box
+        cell_masked = segmentation[bb] == cell_label
+        genes_in_cell = all_genes[(slice(None),) + bb]
+        # per gene, count the voxels inside the cell mask with expression > 0
+        gene_expr_sum = np.sum(genes_in_cell[:, cell_masked] > 0, axis=1)
+        # divide by the cell size and write result
+        cells_expression[cell_idx] = gene_expr_sum / cell_size
+
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        tasks = [tp.submit(compute_expressions, cell_idx, cell_label)
+                 for cell_idx, cell_label in enumerate(labels) if cell_label != 0]
+        [t.result() for t in tasks]
+    return labels, cells_expression
+
+
+def write_genes_table(output_path, expression, gene_names, labels, avail_labels):
+    n_labels = len(labels)
+    n_cols = len(gene_names) + 1
+
+    data = np.zeros((n_labels, n_cols), dtype='float32')
+    data[:, 0] = labels  # first column holds the label ids
+    data[avail_labels, 1:] = expression  # avail_labels are used as row indices
+
+    col_names = ['label_id'] + gene_names
+    assert data.shape[1] == len(col_names)
+    with open(output_path, 'w') as f:
+        csv_writer = csv.writer(f, delimiter='\t')  # table is tab-separated
+        csv_writer.writerow(col_names)
+        csv_writer.writerows(data)
+
+
+def gene_assignments(segmentation_path, segmentation_key,
+                     genes_path, labels, output_path,
+                     n_threads):
+    """ Write a table with genes assigned to segmentation by overlap.
+
+    Arguments:
+        segmentation_path [str] - path to hdf5 file with the cell segmentation
+        segmentation_key [str] - path in file to the segmentation dataset.
+        genes_path [str] - path to hdf5 file with spatial gene expression.
+            We expect the datasets 'genes' and 'gene_names' to be present.
+        labels [np.ndarray] - cell id labels
+        output_path [str] - where to write the result table
+        n_threads [int] - number of threads used for the computation
+    """
+
+    with h5py.File(segmentation_path, 'r') as f:
+        segmentation = f[segmentation_key][:]
+
+    genes_dset = 'genes'
+    names_dset = 'gene_names'
+    with h5py.File(genes_path, 'r') as f:
+        ds = f[genes_dset]
+        gene_shape = ds.shape[1:]
+        all_genes = ds[:]
+        gene_names = [i.decode('utf-8') for i in f[names_dset]]
+
+    # resize the segmentation to gene space; NOTE(review): uint16 cast assumes label ids < 65536
+    segmentation = resize(segmentation.astype("float32"),
+                          shape=gene_shape, order=0).astype('uint16')
+    print("Compute gene expression ...")
+    avail_labels, expression = get_cell_expression(segmentation, all_genes, n_threads)
+
+    print('Save results to %s' % output_path)
+    write_genes_table(output_path, expression, gene_names, labels, avail_labels)
+
+
+# TODO write argument parser
+if __name__ == '__main__':
+    pass
diff --git a/scripts/extension/attributes/vc_assignments.py b/scripts/extension/attributes/vc_assignments.py
new file mode 100644
index 0000000..d2b57a1
--- /dev/null
+++ b/scripts/extension/attributes/vc_assignments.py
@@ -0,0 +1,107 @@
+#! /bin/python
+
+import os
+import sys
+import json
+
+import luigi
+import numpy as np
+
+import cluster_tools.utils.function_utils as fu
+from cluster_tools.utils.task_utils import DummyTask
+from cluster_tools.cluster_tasks import SlurmTask, LocalTask
+from .vc_assignments_impl import vc_assignments as vc_assignments_impl
+
+#
+# VC Assignment Tasks
+#
+
+
+class VCAssignmentsBase(luigi.Task):
+    """ VCAssignments base class
+    """
+
+    task_name = 'vc_assignments'
+    src_file = os.path.abspath(__file__)
+    allow_retry = False
+
+    # paths of the input volumes / tables and of the output table
+    segmentation_path = luigi.Parameter()
+    vc_volume_path = luigi.Parameter()
+    vc_expression_path = luigi.Parameter()
+    med_expression_path = luigi.Parameter()
+    output_path = luigi.Parameter()
+    # upstream task returned by requires(); defaults to DummyTask
+    dependency = luigi.TaskParameter(default=DummyTask())
+
+    def requires(self):
+        return self.dependency
+
+    def run_impl(self):
+        # get the global config and init configs
+        shebang = self.global_config_values()[0]
+        self.init(shebang)
+
+        # load the task config
+        config = self.get_task_config()
+
+        # add the input and output paths to the config;
+        # the job function below reads them back by these keys
+        config.update({'segmentation_path': self.segmentation_path,
+                       'vc_volume_path': self.vc_volume_path,
+                       'vc_expression_path': self.vc_expression_path,
+                       'med_expression_path': self.med_expression_path,
+                       'output_path': self.output_path})
+
+        # prime and run the job
+        self.prepare_jobs(1, None, config)
+        self.submit_jobs(1)
+
+        # wait till jobs finish and check for job success
+        self.wait_for_jobs()
+        self.check_jobs(1)
+
+
+class VCAssignmentsLocal(VCAssignmentsBase, LocalTask):
+    """ VCAssignments on local machine
+    """
+    pass
+
+
+class VCAssignmentsSlurm(VCAssignmentsBase, SlurmTask):
+    """ VCAssignments on slurm cluster
+    """
+    pass
+
+
+#
+# Implementation
+#
+
+def vc_assignments(job_id, config_path):
+    """Job entry point: read the job config and run the vc assignment."""
+    fu.log("start processing job %i" % job_id)
+    fu.log("reading config from %s" % config_path)
+
+    # get the config
+    with open(config_path) as f:
+        config = json.load(f)
+
+    segmentation_path = config['segmentation_path']
+    # the key must match the one written by VCAssignmentsBase.run_impl
+    vc_volume_path = config['vc_volume_path']
+    vc_expression_path = config['vc_expression_path']
+    med_expression_path = config['med_expression_path']
+    output_path = config['output_path']
+    n_threads = config.get('threads_per_job', 1)
+
+    vc_assignments_impl(segmentation_path, vc_volume_path, vc_expression_path,
+                        med_expression_path, output_path, n_threads)
+    fu.log_job_success(job_id)
+
+
+if __name__ == '__main__':
+    path = sys.argv[1]
+    assert os.path.exists(path), path
+    job_id = int(os.path.split(path)[1].split('.')[0].split('_')[-1])
+    vc_assignments(job_id, path)
diff --git a/scripts/extension/attributes/vc_assignments_impl.py b/scripts/extension/attributes/vc_assignments_impl.py
new file mode 100644
index 0000000..791be04
--- /dev/null
+++ b/scripts/extension/attributes/vc_assignments_impl.py
@@ -0,0 +1,181 @@
+# implementation for the tasks in 'vc_assignments.py', can be called standalone
+import os
+import csv
+from concurrent import futures
+
+import h5py
+import argparse
+import numpy as np
+from vigra.analysis import extractRegionFeatures
+from vigra.sampling import resize
+from vigra.filters import distanceTransform
+
+
+def add_path_if_needed(file_path, dir_path):
+    return file_path if os.path.exists(file_path) else os.path.join(dir_path, file_path)
+
+
+def get_common_genes(vc_genes_file_path, cells_gene_expression, med_gene_names):
+    med_gene_indices = []
+    vc_gene_indices = []
+    common_gene_names = []
+    med_gene_names_lowercase = [i.lower().split('-')[0] for i in med_gene_names]
+    # get the names of genes used for vc's (header row of the csv file)
+    with open(vc_genes_file_path) as csv_file:
+        csv_reader = csv.DictReader(csv_file, delimiter=',')
+        vc_gene_names = csv_reader.fieldnames
+    # genes both used for vc's and available as MEDs, matched on lower-cased name prefix
+    for i in range(len(vc_gene_names)):
+        name = vc_gene_names[i].split('--')[0]
+        if name.lower() in med_gene_names_lowercase:
+            med_gene_indices.append(med_gene_names_lowercase.index(name.lower()))
+            vc_gene_indices.append(i)
+            common_gene_names.append(name)
+    # from expression_by_overlap assignment extract only the subset genes
+    cells_expression_subset = np.take(cells_gene_expression, med_gene_indices,
+                                      axis=1)
+    # from vcs_expression extract only the subset genes
+    vc_expression_subset = np.loadtxt(vc_genes_file_path, delimiter=',',
+                                      skiprows=1, usecols=vc_gene_indices)
+    # prepend the null vc (row 0) with no expression
+    vc_expression_subset = np.insert(vc_expression_subset, 0,
+                                     np.zeros(len(vc_gene_indices)),
+                                     axis=0)
+    print(len(common_gene_names), 'common genes found in VCs and MEDs')
+    return cells_expression_subset, vc_expression_subset, common_gene_names
+
+
+def get_bbs(data, offset):
+    shape = np.array(data.shape)
+    # compute the relevant vigra region features
+    # beware: for the absent labels the results are ridiculous
+    features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'),
+                                     features=['Coord<Maximum >', 'Coord<Minimum >'])
+    # bounding boxes from features, enlarged by `offset` on every side
+    mins = features['Coord<Minimum >'] - offset
+    maxs = features['Coord<Maximum >'] + offset + 1
+    # clip to [0, shape] to prevent 'out of range' due to offsets
+    mins[np.where(mins < 0)] = 0
+    maxs[np.where(maxs > shape)] = shape[np.where(maxs > shape)[1]]
+    # get a bb (tuple of slices) for each cell
+    cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_))
+                for min_, max_ in zip(np.uint32(mins), np.uint32(maxs))]
+    return cell_bbs
+
+
+def get_distances(em_data, vc_data, cells_expression, vc_expression, n_threads,
+                  offset=10):
+    """Return the (num_cells, num_vcs) matrix of cell-to-vc expression distances.
+
+    Entries stay NaN for all (cell, vc) pairs that were not evaluated.
+    """
+    num_cells = cells_expression.shape[0]
+    num_vcs = np.max(vc_data) + 1
+    distance_matrix = np.full((num_cells, num_vcs), np.nan)
+    bbs = get_bbs(em_data, offset)
+
+    def cell_ids(cell):
+        bb = bbs[cell]
+        cell_mask = (em_data[bb] == cell).astype("uint32")
+        dist = distanceTransform(cell_mask)
+        vc_roi = vc_data[bb]
+        # keep the vcs whose smallest distance-transform value is <= offset
+        vc_list = [vc for vc in np.unique(vc_roi).astype('int')
+                   if np.min(dist[vc_roi == vc]) <= offset]
+        # always offer the null vc; list.append keeps the indices integer
+        # (np.append on an empty list would yield a float index array)
+        if 0 not in vc_list:
+            vc_list.append(0)
+        # genetic (L1) distance between the cell and the candidate vcs
+        diff = cells_expression[cell] - vc_expression[vc_list]
+        distance_matrix[cell][vc_list] = np.sum(np.abs(diff), axis=1)
+
+    # some labels might be lost due to downsampling; label 0 is skipped
+    avail_cells = [cell for cell in np.unique(em_data) if cell != 0]
+    with futures.ThreadPoolExecutor(n_threads) as tp:
+        tasks = [tp.submit(cell_ids, cell) for cell in avail_cells]
+        [t.result() for t in tasks]
+    return distance_matrix
+
+
+def assign_vc(distances, vc_expression):
+    num_cells = distances.shape[0]
+    # pick the vc with the smallest distance; assign to 0 if no vcs were found at all
+    assignments = [0 if np.all(np.isnan(distances[cell]))
+                   else np.nanargmin(distances[cell])
+                   for cell in range(num_cells)]
+    cells_expr = vc_expression[assignments]  # expression of the assigned vc per cell
+    cells_expr = np.insert(cells_expr, 0, np.arange(num_cells), axis=1)  # prepend row index as id column
+    return cells_expr
+
+
+def vc_assignments(segm_volume_file, vc_volume_file, vc_expr_file,
+                   cells_med_expr_table, output_gene_table, n_threads):
+    em_dset = 't00000/s00/4/cells'
+    cm_dset = 't00000/s00/0/cells'
+    # volume file for vc's (generated from CellModels_coordinates)
+    with h5py.File(vc_volume_file, 'r') as f:
+        vc_data = f[cm_dset][:]
+
+    # resample segmentation to the vc volume shape; NOTE(review): uint16 cast assumes ids < 65536
+    with h5py.File(segm_volume_file, 'r') as f:
+        segm_data = f[em_dset][:]
+    downsampled_segm_data = resize(segm_data.astype("float32"), shape=vc_data.shape,
+                                   order=0).astype('uint16')
+
+    # the table with cell expression by overlap; its first column is not a gene
+    with open(cells_med_expr_table) as csv_file:
+        csv_reader = csv.DictReader(csv_file, delimiter='\t')
+        med_gene_names = csv_reader.fieldnames[1:]
+    cells_expression = np.loadtxt(cells_med_expr_table, delimiter='\t',
+                                  skiprows=1)
+    # get the genes that were both used for vcs and are in med files
+    cells_expression_subset, vc_expression_subset,  common_gene_names = \
+        get_common_genes(vc_expr_file, cells_expression[:, 1:], med_gene_names)
+
+    # get the genetic distance from cells to surrounding vcs
+    dist_matrix = get_distances(downsampled_segm_data, vc_data,
+                                cells_expression_subset,
+                                vc_expression_subset, n_threads)
+    # assign the cells to the genetically closest vcs
+    cell_assign = assign_vc(dist_matrix, vc_expression_subset)
+    # write the result as a tab-separated table
+    col_names = ['label_id'] + common_gene_names
+    assert cell_assign.shape[1] == len(col_names)
+    with open(output_gene_table, 'w') as f:
+        csv_writer = csv.writer(f, delimiter='\t')
+        csv_writer.writerow(col_names)
+        csv_writer.writerows(cell_assign)
+
+
+if __name__ == '__main__':
+
+    # hard-coded default data locations, to make debugging easier
+    platy_data_path = '/g/arendt/EM_6dpf_segmentation/platy-browser-data/data'
+    gene_data_path = '/g/kreshuk/zinchenk/cell_match/data/genes'
+    table_path = 'tables/sbem-6dpf-1-whole-segmented-cells-labels/genes.csv'
+    segm_path = 'segmentations/sbem-6dpf-1-whole-segmented-cells-labels.h5'
+
+    parser = argparse.ArgumentParser(description='Assign cells to genetically closest VCs')
+    parser.add_argument('vc_volume_file', type=str,
+                        help='the h5 file with VC labels')
+    parser.add_argument('vc_profile_file', type=str,
+                        help='table of expression by VC')
+    parser.add_argument('output_file', type=str,
+                        help='the files with cell expression assigned by VC')
+    parser.add_argument('--ov_expr_version', type=str, default='0.5.4',
+                        help='the version of platy data to take expression by overlap from')
+    parser.add_argument('--segm_version', type=str, default='0.3.1',
+                        help='the version of platy data to take segmentation from')
+    args = parser.parse_args()
+
+    gene_table_file = os.path.join(platy_data_path, args.ov_expr_version, table_path)
+    segment_file_path = os.path.join(platy_data_path, args.segm_version, segm_path)
+    vc_volume_file = add_path_if_needed(args.vc_volume_file, gene_data_path)
+    vc_profile_file = add_path_if_needed(args.vc_profile_file, gene_data_path)
+    output_file = add_path_if_needed(args.output_file, gene_data_path)
+
+    # number of threads is not exposed on the command line yet
+    n_threads = 8
+    vc_assignments(segment_file_path, vc_volume_file, vc_profile_file,
+                   gene_table_file, output_file, n_threads)
diff --git a/update_registration.py b/update_registration.py
index 4ecc8d8..8491e2a 100755
--- a/update_registration.py
+++ b/update_registration.py
@@ -18,7 +18,7 @@ from scripts.release_helper import add_version
 from scripts.extension.registration import ApplyRegistrationLocal, ApplyRegistrationSlurm
 from scripts.default_config import get_default_shebang
 from scripts.attributes.base_attributes import base_attributes
-from scripts.attributes.genes import create_auxiliary_gene_file, write_genes_table
+from scripts.attributes.genes import create_auxiliary_gene_file, gene_assignment_table
 from scripts.util import add_max_id
 
 
@@ -176,8 +176,9 @@ def update_prospr(new_folder, input_folder, transformation_file, target, max_job
         assert os.path.islink(out_path), out_path
         print("Remove link to previous gene table:", out_path)
         os.unlink(out_path)
-    write_genes_table(seg_path, aux_out_path, out_path,
-                      labels, tmp_folder, target)
+    gene_assignment_table(seg_path, aux_out_path, out_path,
+                          labels, tmp_folder, target)
+    # TODO update the vc based gene assignments as well
 
     # register virtual cells
     vc_name = 'prospr-6dpf-1-whole-virtual-cells-labels'
-- 
GitLab