Skip to content
Snippets Groups Projects
Commit 2aa06f51 authored by Constantin Pape's avatar Constantin Pape
Browse files

Wrap gene expression in cluster task; vectorize and parallelize computation

parent 7ddc3228
No related branches found
No related tags found
No related merge requests found
import csv
import h5py
import os
import json
import luigi
import numpy as np
from vigra.analysis import extractRegionFeatures
from vigra.sampling import resize
# TODO
# wrap this in a cluster_tools task in order to run remotely
# fix blatant inefficiencis (size loop)
# make test to check against original table
def get_sizes_and_bbs(data):
# compute the relevant vigra region features
features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'),
features=['Coord<Maximum >', 'Coord<Minimum >', 'Count'])
# extract sizes from features
cell_sizes = features['Count'].squeeze().astype('uint64')
# compute bounding boxes from features
mins = features['Coord<Minimum >'].astype('uint32')
maxs = features['Coord<Maximum >'].astype('uint32') + 1
cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_))
for min_, max_ in zip(mins, maxs)]
return cell_sizes, cell_bbs
def get_cell_expression(segm_data, all_genes):
num_genes = all_genes.shape[0]
labels = list(np.unique(segm_data))
cells_expression = np.zeros((len(labels), num_genes), dtype='float32')
cell_sizes, cell_bbs = get_sizes_and_bbs(segm_data)
for cell_idx in range(len(labels)):
cell_label = labels[cell_idx]
if cell_label == 0:
continue
cell_size = cell_sizes[cell_label]
bb = cell_bbs[cell_label]
cell_masked = (segm_data[bb] == cell_label)
genes_in_cell = all_genes[tuple([slice(0, None)] + list(bb))]
for gene in range(num_genes):
gene_expr = genes_in_cell[gene]
gene_expr_sum = np.sum(gene_expr[cell_masked] > 0)
cells_expression[cell_idx, gene] = gene_expr_sum / cell_size
return labels, cells_expression
def write_genes_table(segm_file, genes_file, table_file, labels):
dset = 't00000/s00/4/cells'
new_shape = (570, 518, 550)
genes_dset = 'genes'
names_dset = 'gene_names'
with h5py.File(segm_file, 'r') as f:
segment_data = f[dset][:]
# TODO loading the whole thing into ram takes a lot of memory
with h5py.File(genes_file, 'r') as f:
all_genes = f[genes_dset][:]
gene_names = [i.decode('utf-8') for i in f[names_dset]]
num_genes = len(gene_names)
downsampled_data = resize(segment_data.astype("float32"), shape=new_shape, order=0).astype('uint16')
avail_labels, expression = get_cell_expression(downsampled_data, all_genes)
with open(table_file, 'w') as genes_table:
csv_writer = csv.writer(genes_table, delimiter='\t')
_ = csv_writer.writerow(['label_id'] + gene_names)
for label in labels:
if label in avail_labels:
idx = avail_labels.index(label)
_ = csv_writer.writerow([label] + list(expression[idx]))
else:
_ = csv_writer.writerow([label] + [0] * num_genes)
from ..extension.attributes import GenesLocal, GenesSlurm
def write_genes_table(segm_file, genes_file, table_file, labels,
tmp_folder, target, n_threads=8):
task = GenesSlurm if target == 'slurm' else GenesLocal
seg_dset = 't00000/s00/4/cells'
# TODO would be good not to hard-code this
gene_shape = (570, 518, 550)
config_folder = os.path.join(tmp_folder, 'configs')
config = task.default_task_config()
# this is very ram hungry because we load all the genes at once
config.update({'threads_per_job': n_threads, 'mem_limit': 256})
with open(os.path.join(config_folder, 'genes.config'), 'w') as f:
json.dump(config, f)
# we need to serialize the labels so that they can be loaded by the luigi task
labels_path = os.path.join(tmp_folder, 'unique_labels.npy')
np.save(labels_path, labels)
t = task(tmp_folder=tmp_folder, config_dir=config_folder, max_jobs=1,
segmentation_path=segm_file, segmentation_key=seg_dset,
genes_path=genes_file, labels_path=labels_path,
output_path=table_file, gene_shape=gene_shape)
ret = luigi.build([t], local_scheduler=True)
if not ret:
raise RuntimeError("Computing gene expressions failed")
......@@ -50,7 +50,8 @@ def make_cell_tables(folder, name, tmp_folder, resolution,
if not os.path.exists(aux_gene_path):
raise RuntimeError("Can't find auxiliary gene file")
gene_out = os.path.join(table_folder, 'genes.csv')
write_genes_table(seg_path, aux_gene_path, gene_out, label_ids)
write_genes_table(seg_path, aux_gene_path, gene_out, label_ids,
tmp_folder, target)
# make table with morphology
morpho_out = os.path.join(table_folder, 'morphology.csv')
......
from .genes import GenesLocal, GenesSlurm
#! /bin/python
import os
import sys
import json
import csv
from concurrent import futures
import luigi
import numpy as np
from vigra.analysis import extractRegionFeatures
from vigra.sampling import resize
import cluster_tools.utils.volume_utils as vu
import cluster_tools.utils.function_utils as fu
from cluster_tools.utils.task_utils import DummyTask
from cluster_tools.cluster_tasks import SlurmTask, LocalTask
#
# Gene Attribute Tasks
#
class GenesBase(luigi.Task):
""" Genes base class
"""
task_name = 'genes'
src_file = os.path.abspath(__file__)
allow_retry = False
# input volumes and graph
segmentation_path = luigi.Parameter()
segmentation_key = luigi.Parameter()
genes_path = luigi.Parameter()
labels_path = luigi.Parameter()
output_path = luigi.Parameter()
gene_shape = luigi.ListParameter()
#
dependency = luigi.TaskParameter(default=DummyTask())
def requires(self):
return self.dependency
def run_impl(self):
# get the global config and init configs
shebang = self.global_config_values()[0]
self.init(shebang)
# load the task config
config = self.get_task_config()
# update the config with input and graph paths and keys
# as well as block shape
config.update({'segmentation_path': self.segmentation_path,
'segmentation_key': self.segmentation_key,
'genes_path': self.genes_path,
'output_path': self.output_path,
'labels_path': self.labels_path,
'gene_shape': self.gene_shape})
# prime and run the job
self.prepare_jobs(1, None, config)
self.submit_jobs(1)
# wait till jobs finish and check for job success
self.wait_for_jobs()
self.check_jobs(1)
class GenesLocal(GenesBase, LocalTask):
""" Genes on local machine
"""
pass
class GenesSlurm(GenesBase, SlurmTask):
""" Genes on slurm cluster
"""
pass
#
# Implementation
#
def get_sizes_and_bbs(data):
# compute the relevant vigra region features
features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'),
features=['Coord<Maximum >', 'Coord<Minimum >', 'Count'])
# extract sizes from features
cell_sizes = features['Count'].squeeze().astype('uint64')
# compute bounding boxes from features
mins = features['Coord<Minimum >'].astype('uint32')
maxs = features['Coord<Maximum >'].astype('uint32') + 1
cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_))
for min_, max_ in zip(mins, maxs)]
return cell_sizes, cell_bbs
def get_cell_expression(segmentation, all_genes, n_threads):
num_genes = all_genes.shape[0]
# NOTE we need to recalculate the unique labels here, beacause we might not
# have all labels due to donwsampling
labels = np.unique(segmentation)
cells_expression = np.zeros((len(labels), num_genes), dtype='float32')
cell_sizes, cell_bbs = get_sizes_and_bbs(segmentation)
def compute_expressions(cell_idx, cell_label):
# get size and boundinng box of this cell
cell_size = cell_sizes[cell_label]
bb = cell_bbs[cell_label]
# get the cell mask and the gene expression in bounding box
cell_masked = segmentation[bb] == cell_label
genes_in_cell = all_genes[(slice(None),) + bb]
# accumulate the gene expression channels over the cell mask
gene_expr_sum = np.sum(genes_in_cell[:, cell_masked] > 0, axis=1)
# divide by the cell size and write result
cells_expression[cell_idx] = gene_expr_sum / cell_size
with futures.ThreadPoolExecutor(n_threads) as tp:
tasks = [tp.submit(compute_expressions, cell_idx, cell_label)
for cell_idx, cell_label in enumerate(labels) if cell_label != 0]
[t.result() for t in tasks]
return labels, cells_expression
def write_genes_table(output_path, expression, gene_names, labels, avail_labels):
n_labels = len(labels)
n_cols = len(gene_names) + 1
data = np.zeros((n_labels, n_cols), dtype='float32')
data[:, 0] = labels
data[avail_labels, 1:] = expression
col_names = ['label_id'] + gene_names
assert data.shape[1] == len(col_names)
with open(output_path, 'w') as f:
csv_writer = csv.writer(f, delimiter='\t')
csv_writer.writerow(col_names)
csv_writer.writerows(data)
def genes(job_id, config_path):
fu.log("start processing job %i" % job_id)
fu.log("reading config from %s" % config_path)
# get the config
with open(config_path) as f:
config = json.load(f)
segmentation_path = config['segmentation_path']
segmentation_key = config['segmentation_key']
genes_path = config['genes_path']
labels_path = config['labels_path']
output_path = config['output_path']
gene_shape = tuple(config['gene_shape'])
n_threads = config.get('threads_per_job', 1)
fu.log("Loading segmentation, labels and gene-data")
# load segmentation, labels and genes
with vu.file_reader(segmentation_path, 'r') as f:
segmentation = f[segmentation_key][:]
labels = np.load(labels_path)
genes_dset = 'genes'
names_dset = 'gene_names'
with vu.file_reader(genes_path, 'r') as f:
all_genes = f[genes_dset][:]
gene_names = [i.decode('utf-8') for i in f[names_dset]]
# resize the segmentation to gene space
segmentation = resize(segmentation.astype("float32"),
shape=gene_shape, order=0).astype('uint16')
fu.log("Compute gene expression")
avail_labels, expression = get_cell_expression(segmentation, all_genes, n_threads)
fu.log('Save results to %s' % output_path)
write_genes_table(output_path, expression, gene_names, labels, avail_labels)
fu.log_job_success(job_id)
if __name__ == '__main__':
path = sys.argv[1]
assert os.path.exists(path), path
job_id = int(os.path.split(path)[1].split('.')[0].split('_')[-1])
genes(job_id, path)
import unittest
import sys
import os
import json
import sys
import numpy as np
from shutil import rmtree
sys.path.append('../..')
# check new version of gene mapping against original
class TestGeneAttributes(unittest.TestCase):
test_file = 'test_table.csv'
tmp_folder = 'tmp_genes'
test_file = 'tmp_genes/test_table.csv'
def tearDown(self):
def _tearDown(self):
try:
os.remove(self.test_file)
rmtree(self.tmp_folder)
except OSError:
pass
......@@ -22,6 +25,7 @@ class TestGeneAttributes(unittest.TestCase):
def test_genes(self):
from scripts.attributes.genes import write_genes_table
from scripts.extension.attributes import GenesLocal
from scripts.files import get_h5_path_from_xml
# load original genes table
......@@ -35,8 +39,19 @@ class TestGeneAttributes(unittest.TestCase):
genes_file = '../../data/0.0.0/misc/meds_all_genes.xml'
genes_file = get_h5_path_from_xml(genes_file)
table_file = self.test_file
# write the global config
config_folder = os.path.join(self.tmp_folder, 'configs')
os.makedirs(config_folder, exist_ok=True)
conf = GenesLocal.default_global_config()
shebang = '#! /g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python'
conf.update({'shebang': shebang})
with open(os.path.join(config_folder, 'global.config'), 'w') as f:
json.dump(conf, f)
print("Start computation ...")
write_genes_table(segm_file, genes_file, table_file, labels)
write_genes_table(segm_file, genes_file, table_file, labels,
self.tmp_folder, 'local', 8)
table = self.load_table(table_file)
# make sure new and old table agree
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment