diff --git a/scripts/attributes/genes.py b/scripts/attributes/genes.py index 32341067892f4ec4c923759c9cb7c56a2011302e..58b2bf9ae1cb60b5f163541eb4bad6664d581858 100755 --- a/scripts/attributes/genes.py +++ b/scripts/attributes/genes.py @@ -1,76 +1,33 @@ -import csv -import h5py +import os +import json +import luigi import numpy as np -from vigra.analysis import extractRegionFeatures -from vigra.sampling import resize - - -# TODO -# wrap this in a cluster_tools task in order to run remotely -# fix blatant inefficiencis (size loop) -# make test to check against original table - - -def get_sizes_and_bbs(data): - # compute the relevant vigra region features - features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'), - features=['Coord<Maximum >', 'Coord<Minimum >', 'Count']) - - # extract sizes from features - cell_sizes = features['Count'].squeeze().astype('uint64') - - # compute bounding boxes from features - mins = features['Coord<Minimum >'].astype('uint32') - maxs = features['Coord<Maximum >'].astype('uint32') + 1 - cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_)) - for min_, max_ in zip(mins, maxs)] - return cell_sizes, cell_bbs - - -def get_cell_expression(segm_data, all_genes): - num_genes = all_genes.shape[0] - labels = list(np.unique(segm_data)) - cells_expression = np.zeros((len(labels), num_genes), dtype='float32') - cell_sizes, cell_bbs = get_sizes_and_bbs(segm_data) - for cell_idx in range(len(labels)): - cell_label = labels[cell_idx] - if cell_label == 0: - continue - cell_size = cell_sizes[cell_label] - bb = cell_bbs[cell_label] - cell_masked = (segm_data[bb] == cell_label) - genes_in_cell = all_genes[tuple([slice(0, None)] + list(bb))] - for gene in range(num_genes): - gene_expr = genes_in_cell[gene] - gene_expr_sum = np.sum(gene_expr[cell_masked] > 0) - cells_expression[cell_idx, gene] = gene_expr_sum / cell_size - return labels, cells_expression - - -def write_genes_table(segm_file, genes_file, table_file, labels): - dset = 't00000/s00/4/cells' - new_shape = (570, 518, 550) - genes_dset = 'genes' - names_dset = 'gene_names' - - with h5py.File(segm_file, 'r') as f: - segment_data = f[dset][:] - - # TODO loading the whole thing into ram takes a lot of memory - with h5py.File(genes_file, 'r') as f: - all_genes = f[genes_dset][:] - gene_names = [i.decode('utf-8') for i in f[names_dset]] - - num_genes = len(gene_names) - downsampled_data = resize(segment_data.astype("float32"), shape=new_shape, order=0).astype('uint16') - avail_labels, expression = get_cell_expression(downsampled_data, all_genes) - - with open(table_file, 'w') as genes_table: - csv_writer = csv.writer(genes_table, delimiter='\t') - _ = csv_writer.writerow(['label_id'] + gene_names) - for label in labels: - if label in avail_labels: - idx = avail_labels.index(label) - _ = csv_writer.writerow([label] + list(expression[idx])) - else: - _ = csv_writer.writerow([label] + [0] * num_genes) +from ..extension.attributes import GenesLocal, GenesSlurm + + +def write_genes_table(segm_file, genes_file, table_file, labels, + tmp_folder, target, n_threads=8): + task = GenesSlurm if target == 'slurm' else GenesLocal + + seg_dset = 't00000/s00/4/cells' + # TODO would be good not to hard-code this + gene_shape = (570, 518, 550) + + config_folder = os.path.join(tmp_folder, 'configs') + config = task.default_task_config() + # this is very ram hungry because we load all the genes at once + config.update({'threads_per_job': n_threads, 'mem_limit': 256}) + with open(os.path.join(config_folder, 'genes.config'), 'w') as f: + json.dump(config, f) + + # we need to serialize the labels so that they can be loaded by the luigi task + labels_path = os.path.join(tmp_folder, 'unique_labels.npy') + np.save(labels_path, labels) + + t = task(tmp_folder=tmp_folder, config_dir=config_folder, max_jobs=1, + segmentation_path=segm_file, segmentation_key=seg_dset, + genes_path=genes_file, labels_path=labels_path, + output_path=table_file, gene_shape=gene_shape) + ret = luigi.build([t], local_scheduler=True) + if not ret: + raise RuntimeError("Computing gene expressions failed") diff --git a/scripts/attributes/master.py b/scripts/attributes/master.py index ce8a00a5d8fbedf27092e80213e8e1896ad57dec..2f3129ac8f6745bd4309b9f3131cb3f87ca4e464 100644 --- a/scripts/attributes/master.py +++ b/scripts/attributes/master.py @@ -50,7 +50,8 @@ def make_cell_tables(folder, name, tmp_folder, resolution, if not os.path.exists(aux_gene_path): raise RuntimeError("Can't find auxiliary gene file") gene_out = os.path.join(table_folder, 'genes.csv') - write_genes_table(seg_path, aux_gene_path, gene_out, label_ids) + write_genes_table(seg_path, aux_gene_path, gene_out, label_ids, + tmp_folder, target) # make table with morphology morpho_out = os.path.join(table_folder, 'morphology.csv') diff --git a/scripts/extension/__init__.py b/scripts/extension/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/extension/attributes/__init__.py b/scripts/extension/attributes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..330dcad79b573469373772e1ae61236d06dd3f6a --- /dev/null +++ b/scripts/extension/attributes/__init__.py @@ -0,0 +1 @@ +from .genes import GenesLocal, GenesSlurm diff --git a/scripts/extension/attributes/genes.py b/scripts/extension/attributes/genes.py new file mode 100644 index 0000000000000000000000000000000000000000..f34eb7cd12b754e5756901746fa8f74e7cccd7e7 --- /dev/null +++ b/scripts/extension/attributes/genes.py @@ -0,0 +1,190 @@ +#! /bin/python + +import os +import sys +import json +import csv +from concurrent import futures + +import luigi +import numpy as np +from vigra.analysis import extractRegionFeatures +from vigra.sampling import resize + +import cluster_tools.utils.volume_utils as vu +import cluster_tools.utils.function_utils as fu +from cluster_tools.utils.task_utils import DummyTask +from cluster_tools.cluster_tasks import SlurmTask, LocalTask + +# +# Gene Attribute Tasks +# + + +class GenesBase(luigi.Task): + """ Genes base class + """ + + task_name = 'genes' + src_file = os.path.abspath(__file__) + allow_retry = False + + # input volumes and graph + segmentation_path = luigi.Parameter() + segmentation_key = luigi.Parameter() + genes_path = luigi.Parameter() + labels_path = luigi.Parameter() + output_path = luigi.Parameter() + gene_shape = luigi.ListParameter() + # + dependency = luigi.TaskParameter(default=DummyTask()) + + def requires(self): + return self.dependency + + def run_impl(self): + # get the global config and init configs + shebang = self.global_config_values()[0] + self.init(shebang) + + # load the task config + config = self.get_task_config() + + # update the config with input and graph paths and keys + # as well as block shape + config.update({'segmentation_path': self.segmentation_path, + 'segmentation_key': self.segmentation_key, + 'genes_path': self.genes_path, + 'output_path': self.output_path, + 'labels_path': self.labels_path, + 'gene_shape': self.gene_shape}) + + # prime and run the job + self.prepare_jobs(1, None, config) + self.submit_jobs(1) + + # wait till jobs finish and check for job success + self.wait_for_jobs() + self.check_jobs(1) + + +class GenesLocal(GenesBase, LocalTask): + """ Genes on local machine + """ + pass + + +class GenesSlurm(GenesBase, SlurmTask): + """ Genes on slurm cluster + """ + pass + + +# +# Implementation +# + + +def get_sizes_and_bbs(data): + # compute the relevant vigra region features + features = extractRegionFeatures(data.astype('float32'), data.astype('uint32'), + features=['Coord<Maximum >', 'Coord<Minimum >', 'Count']) + + # extract sizes from features + cell_sizes = features['Count'].squeeze().astype('uint64') + + # compute bounding boxes from features + mins = features['Coord<Minimum >'].astype('uint32') + maxs = features['Coord<Maximum >'].astype('uint32') + 1 + cell_bbs = [tuple(slice(mi, ma) for mi, ma in zip(min_, max_)) + for min_, max_ in zip(mins, maxs)] + return cell_sizes, cell_bbs + + +def get_cell_expression(segmentation, all_genes, n_threads): + num_genes = all_genes.shape[0] + # NOTE we need to recalculate the unique labels here, beacause we might not + # have all labels due to donwsampling + labels = np.unique(segmentation) + cells_expression = np.zeros((len(labels), num_genes), dtype='float32') + cell_sizes, cell_bbs = get_sizes_and_bbs(segmentation) + + def compute_expressions(cell_idx, cell_label): + # get size and boundinng box of this cell + cell_size = cell_sizes[cell_label] + bb = cell_bbs[cell_label] + # get the cell mask and the gene expression in bounding box + cell_masked = segmentation[bb] == cell_label + genes_in_cell = all_genes[(slice(None),) + bb] + # accumulate the gene expression channels over the cell mask + gene_expr_sum = np.sum(genes_in_cell[:, cell_masked] > 0, axis=1) + # divide by the cell size and write result + cells_expression[cell_idx] = gene_expr_sum / cell_size + + with futures.ThreadPoolExecutor(n_threads) as tp: + tasks = [tp.submit(compute_expressions, cell_idx, cell_label) + for cell_idx, cell_label in enumerate(labels) if cell_label != 0] + [t.result() for t in tasks] + return labels, cells_expression + + +def write_genes_table(output_path, expression, gene_names, labels, avail_labels): + n_labels = len(labels) + n_cols = len(gene_names) + 1 + + data = np.zeros((n_labels, n_cols), dtype='float32') + data[:, 0] = labels + data[avail_labels, 1:] = expression + + col_names = ['label_id'] + gene_names + assert data.shape[1] == len(col_names) + with open(output_path, 'w') as f: + csv_writer = csv.writer(f, delimiter='\t') + csv_writer.writerow(col_names) + csv_writer.writerows(data) + + +def genes(job_id, config_path): + + fu.log("start processing job %i" % job_id) + fu.log("reading config from %s" % config_path) + + # get the config + with open(config_path) as f: + config = json.load(f) + segmentation_path = config['segmentation_path'] + segmentation_key = config['segmentation_key'] + genes_path = config['genes_path'] + labels_path = config['labels_path'] + output_path = config['output_path'] + gene_shape = tuple(config['gene_shape']) + n_threads = config.get('threads_per_job', 1) + + fu.log("Loading segmentation, labels and gene-data") + # load segmentation, labels and genes + with vu.file_reader(segmentation_path, 'r') as f: + segmentation = f[segmentation_key][:] + labels = np.load(labels_path) + + genes_dset = 'genes' + names_dset = 'gene_names' + with vu.file_reader(genes_path, 'r') as f: + all_genes = f[genes_dset][:] + gene_names = [i.decode('utf-8') for i in f[names_dset]] + + # resize the segmentation to gene space + segmentation = resize(segmentation.astype("float32"), + shape=gene_shape, order=0).astype('uint16') + fu.log("Compute gene expression") + avail_labels, expression = get_cell_expression(segmentation, all_genes, n_threads) + + fu.log('Save results to %s' % output_path) + write_genes_table(output_path, expression, gene_names, labels, avail_labels) + fu.log_job_success(job_id) + + +if __name__ == '__main__': + path = sys.argv[1] + assert os.path.exists(path), path + job_id = int(os.path.split(path)[1].split('.')[0].split('_')[-1]) + genes(job_id, path) diff --git a/test/attributes/test_genes.py b/test/attributes/test_genes.py index 8f34910deaf4b613457b2ebe8467745bdd20a3d5..ff8fe8004ffe70cc43b6d3125725dd92d70937aa 100644 --- a/test/attributes/test_genes.py +++ b/test/attributes/test_genes.py @@ -1,17 +1,20 @@ import unittest -import sys import os +import json +import sys import numpy as np +from shutil import rmtree sys.path.append('../..') # check new version of gene mapping against original class TestGeneAttributes(unittest.TestCase): - test_file = 'test_table.csv' + tmp_folder = 'tmp_genes' + test_file = 'tmp_genes/test_table.csv' - def tearDown(self): + def _tearDown(self): try: - os.remove(self.test_file) + rmtree(self.tmp_folder) except OSError: pass @@ -22,6 +25,7 @@ class TestGeneAttributes(unittest.TestCase): def test_genes(self): from scripts.attributes.genes import write_genes_table + from scripts.extension.attributes import GenesLocal from scripts.files import get_h5_path_from_xml # load original genes table @@ -35,8 +39,19 @@ class TestGeneAttributes(unittest.TestCase): genes_file = '../../data/0.0.0/misc/meds_all_genes.xml' genes_file = get_h5_path_from_xml(genes_file) table_file = self.test_file + + # write the global config + config_folder = os.path.join(self.tmp_folder, 'configs') + os.makedirs(config_folder, exist_ok=True) + conf = GenesLocal.default_global_config() + shebang = '#! /g/kreshuk/pape/Work/software/conda/miniconda3/envs/cluster_env37/bin/python' + conf.update({'shebang': shebang}) + with open(os.path.join(config_folder, 'global.config'), 'w') as f: + json.dump(conf, f) + print("Start computation ...") - write_genes_table(segm_file, genes_file, table_file, labels) + write_genes_table(segm_file, genes_file, table_file, labels, + self.tmp_folder, 'local', 8) table = self.load_table(table_file) # make sure new and old table agree