From a6f91df44c7cfa2e418546888dbc9c4328ff4ce0 Mon Sep 17 00:00:00 2001
From: Constantin Pape <constantin.pape@iwr.uni-heidelberg.de>
Date: Fri, 23 Aug 2019 20:14:00 +0200
Subject: [PATCH] Add gene expression computation

---
 analysis/gene_expression.py    | 36 ++++++++++++++++++++++++++++++++++
 scripts/__init__.py            |  1 +
 scripts/analysis/__init__.py   |  1 +
 scripts/analysis/expression.py | 25 +++++++++++++++++++++++
 scripts/release_helper.py      |  6 ++++++
 5 files changed, 69 insertions(+)
 create mode 100644 analysis/gene_expression.py
 create mode 100644 scripts/analysis/expression.py

diff --git a/analysis/gene_expression.py b/analysis/gene_expression.py
new file mode 100644
index 0000000..becb990
--- /dev/null
+++ b/analysis/gene_expression.py
@@ -0,0 +1,36 @@
+#! /g/arendt/pape/miniconda3/envs/platybrowser/bin/python
+import argparse
+import os
+from scripts import get_latest_version
+from scripts.analysis import get_cells_expressing_genes
+
+
+def count_gene_expression(gene_names, threshold, version):
+    """Print how many cells express all `gene_names` above `threshold`.
+
+    An empty `version` selects the latest version. The cwd is restored on exit.
+    """
+    start_dir = os.getcwd()
+    os.chdir('..')  # paths are hard-coded relative to the parent directory
+    try:
+        if version == '':
+            version = get_latest_version()
+        # TODO enable using vc assignments once we have them on master
+        table_path = 'data/%s/tables/sbem-6dpf-1-whole-segmented-cells-labels/genes.csv' % version
+        ids = get_cells_expressing_genes(table_path, threshold, gene_names)
+        print("Found", len(ids), "cells expressing:", ",".join(gene_names))
+    finally:
+        os.chdir(start_dir)  # also runs on KeyboardInterrupt / SystemExit
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Compute number of cells co-expressing genes.')
+    parser.add_argument('gene_names', type=str, nargs='+',
+                        help='Names of the genes for which to compute co-expression.')
+    parser.add_argument('--threshold', type=float, default=.5,
+                        help='Threshold to count gene expression. Default is 0.5.')
+    parser.add_argument('--version', type=str, default='',
+                        help='Version of the platy browser data. Default is latest.')
+    # the empty default for --version makes count_gene_expression pick the latest version
+    args = parser.parse_args()
+    count_gene_expression(args.gene_names, args.threshold, args.version)
diff --git a/scripts/__init__.py b/scripts/__init__.py
index e69de29..44b5d02 100644
--- a/scripts/__init__.py
+++ b/scripts/__init__.py
@@ -0,0 +1 @@
+from .release_helper import get_latest_version
diff --git a/scripts/analysis/__init__.py b/scripts/analysis/__init__.py
index 8c6ed4e..f73e0c5 100644
--- a/scripts/analysis/__init__.py
+++ b/scripts/analysis/__init__.py
@@ -1 +1,2 @@
 from .counts import cell_counts
+from .expression import get_cells_expressing_genes
diff --git a/scripts/analysis/expression.py b/scripts/analysis/expression.py
new file mode 100644
index 0000000..e3c4f4f
--- /dev/null
+++ b/scripts/analysis/expression.py
@@ -0,0 +1,25 @@
+import numpy as np
+import pandas as pd
+
+
+def get_cells_expressing_genes(table_path, expression_threshold, gene_names):
+    """Return the label ids of the cells that express all of `gene_names`.
+
+    Reads the tab-separated table at `table_path`, which must have a
+    'label_id' column and one column per gene. A gene counts as expressed
+    where its column is > `expression_threshold`. `gene_names` is a str or a
+    non-empty list/tuple of str (ValueError otherwise); names without a
+    column raise RuntimeError. Returns a numpy array of label ids.
+    """
+    if isinstance(gene_names, str):
+        gene_names = [gene_names]
+    if not isinstance(gene_names, (list, tuple)) or len(gene_names) == 0:
+        raise ValueError("Gene names must be a str or a list of strings")
+    table = pd.read_csv(table_path, sep='\t')
+    unmatched = sorted(set(gene_names) - set(table.columns))
+    if unmatched:
+        raise RuntimeError("Could not find gene names %s in table %s" % (", ".join(unmatched),
+                                                                         table_path))
+    # a row (cell) is selected only if all requested genes exceed the threshold
+    expressing = (table[list(gene_names)] > expression_threshold).all(axis=1)
+    return table['label_id'][expressing].values
diff --git a/scripts/release_helper.py b/scripts/release_helper.py
index 7530cd3..a118929 100644
--- a/scripts/release_helper.py
+++ b/scripts/release_helper.py
@@ -121,3 +121,9 @@ def add_version(tag):
     versions.append(tag)
     with open(VERSION_FILE, 'w') as f:
         json.dump(versions, f)
+
+
+def get_latest_version():
+    """Return the last entry of the version list stored in VERSION_FILE."""
+    with open(VERSION_FILE) as f:
+        return json.load(f)[-1]
-- 
GitLab