From f3038fb1a32d972bf46d9286a37633a818e34b51 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Wed, 3 Aug 2022 14:58:00 +0200
Subject: [PATCH] Add scripts to benchmark the effect of threads on the query mapping performance

---
 benches/data/.gitignore  |  1 +
 benches/data/download.py | 79 ++++++++++++++++++++++++++++++++++
 benches/mapping/bench.py | 91 ++++++++++++++++++++++++++++++++++++++++
 benches/mapping/plot.py  | 89 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 260 insertions(+)
 create mode 100644 benches/data/.gitignore
 create mode 100644 benches/data/download.py
 create mode 100644 benches/mapping/bench.py
 create mode 100644 benches/mapping/plot.py

diff --git a/benches/data/.gitignore b/benches/data/.gitignore
new file mode 100644
index 0000000..345cc57
--- /dev/null
+++ b/benches/data/.gitignore
@@ -0,0 +1 @@
+*.fna
diff --git a/benches/data/download.py b/benches/data/download.py
new file mode 100644
index 0000000..1a207bd
--- /dev/null
+++ b/benches/data/download.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+"""Download the reference genomes used by the benchmarks.
+
+Every sample listed in ``SAMPLES`` is fetched from the proGenomes website,
+decompressed on the fly and stored as ``{sample}.fna`` next to this script.
+"""
+
+import os
+import gzip
+import shutil
+import urllib.request
+
+import rich.progress
+
+# Sample identifiers; the part before the first dot (the `tax_id` below) is
+# sent to the server as a separate URL parameter.
+SAMPLES = [
+    "76859.SAMN03263149",
+    "76856.SAMN03263147",
+    "562982.SAMN02595332",
+    "154046.SAMEA3545258",
+    "322095.SAMN03854412",
+    "596327.SAMN00002220",
+    "28131.SAMD00034934",
+    "1123263.SAMN02441153",
+    "33039.SAMEA3545241",
+    "1512.SAMEA3545330",
+    "926561.SAMN02261313",
+    "1121279.SAMN02745887",
+    "1216006.SAMEA4552037",
+    "715226.SAMN02469919",
+    "1121935.SAMN02440417",
+    "1210082.SAMD00040656",
+    "1189612.SAMN02469969",
+    "518637.SAMN00008825",
+    "585506.SAMN00139194",
+    "82977.SAMN02928654",
+    "1267533.SAMN02440487",
+    "373.SAMN04223557",
+    "298701.SAMN02471586",
+    "1852381.SAMEA4555866",
+    "1341181.SAMN02470908",
+    "1194088.SAMN02470121",
+    "198092.SAMN02745194",
+    "1561204.SAMN03107784",
+    "570952.SAMN02441250",
+    "419481.SAMN05216233",
+    "888845.SAMN05017598",
+    "1254432.SAMN02603858",
+    "1909395.SAMN05912833",
+    "448385.SAMEA3138271",
+    "1343740.SAMN02604192",
+    "48.SAMN03704078",
+    "1391654.SAMN03951129",
+    "1912.SAMN05935554",
+    "576784.SAMEA2519540",
+    "749414.SAMN02603683",
+    "722472.SAMN05444171",
+    "1173028.SAMN02261253",
+    "1134687.SAMN02463898",
+    "398579.SAMN02598386",
+    "54388.SAMEA2645827",
+    "1773.SAMN06010452",
+    "1281227.SAMN01885939",
+    "645465.SAMN02595349",
+    "36809.SAMN04572937",
+    "649639.SAMN00007429",
+]
+
+# Write next to this script so the result does not depend on the working directory.
+data_folder = os.path.dirname(os.path.realpath(__file__))
+for sample in rich.progress.track(SAMPLES, description="Downloading..."):
+    tax_id = sample.split(".")[0]
+    url = "https://progenomes.embl.de/dumpSequence.cgi?p={}&t=c&a={}".format(sample, tax_id)
+    # The response is read through gzip, so it is decompressed while copied to disk.
+    with urllib.request.urlopen(url) as res:
+        with gzip.open(res) as src:
+            with open(os.path.join(data_folder, "{}.fna".format(sample)), "wb") as dst:
+                shutil.copyfileobj(src, dst)
diff --git a/benches/mapping/bench.py b/benches/mapping/bench.py
new file mode 100644
index 0000000..a9995ad
--- /dev/null
+++ b/benches/mapping/bench.py
@@ -0,0 +1,91 @@
+"""Benchmark the effect of the thread count on query mapping performance.
+
+Every genome found in the data folder is added to a sketch, then queried
+against the resulting mapper with 1 to ``--jobs`` threads, ``--runs`` times
+each. The timings are written to the ``--output`` file in JSON format.
+"""
+
+import argparse
+import glob
+import json
+import os
+import time
+import statistics
+import sys
+import warnings
+from itertools import islice
+
+sys.path.insert(0, os.path.realpath(os.path.join(__file__, "..", "..")))
+
+import Bio.SeqIO
+import numpy
+import pandas
+import rich.progress
+from pyfastani import Sketch
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-r", "--runs", default=3, type=int)
+parser.add_argument("-d", "--data", required=True)
+parser.add_argument("-o", "--output", required=True)
+parser.add_argument("-j", "--jobs", default=os.cpu_count() or 1, type=int)
+args = parser.parse_args()
+if args.runs < 1:
+    # the statistics below are undefined without at least one measurement
+    parser.error("--runs must be at least 1")
+
+# Load every genome from the folder given with `--data`, plus any genome
+# found in `vendor/FastANI/data` (relative to the working directory).
+genomes = [
+    list(Bio.SeqIO.parse(filename, "fasta"))
+    for filename in (
+        glob.glob(os.path.join(args.data, "*.fna"))
+        + glob.glob("vendor/FastANI/data/*.fna")
+    )
+]
+
+with rich.progress.Progress(transient=True) as progress:
+    # Print warnings through the progress console instead of raw stderr.
+    warnings.showwarning = lambda msg, c, f, l, file=None, line=None: progress.print(msg)
+
+    # Add every genome to a single sketch, one draft per genome.
+    sketch = Sketch()
+    task = progress.add_task(total=len(genomes), description="Sketching...")
+    for genome in progress.track(genomes, task_id=task):
+        sketch.add_draft(genome[0].id, [ bytes(contig.seq) for contig in genome ])
+    progress.remove_task(task_id=task)
+
+    mapper = sketch.index()
+
+    results = dict(results=[])
+    task = progress.add_task(total=len(genomes), description="Querying...")
+    for genome in progress.track(genomes, task_id=task):
+        contigs = [ bytes(contig.seq) for contig in genome ]
+        task2 = progress.add_task(total=args.jobs, description="Threads...")
+        for thread_count in progress.track(range(1, args.jobs+1), task_id=task2):
+            times = []
+            task3 = progress.add_task(total=args.runs, description="Repeat...")
+            for _ in progress.track(range(args.runs), task_id=task3):
+                # perf_counter is monotonic and uses the highest resolution
+                # clock available, unlike the wall clock read by time.time
+                t1 = time.perf_counter()
+                mapper.query_draft(contigs, threads=thread_count)
+                t2 = time.perf_counter()
+                times.append(t2 - t1)
+            progress.remove_task(task3)
+            results["results"].append({
+                "genome": genome[0].id,
+                "threads": thread_count,
+                "nucleotides": sum(map(len, contigs)),
+                "times": times,
+                "mean": statistics.mean(times),
+                # the sample standard deviation needs at least two runs
+                "stddev": statistics.stdev(times) if len(times) > 1 else 0.0,
+                "median": statistics.median(times),
+                "min": min(times),
+                "max": max(times),
+            })
+        progress.remove_task(task_id=task2)
+    progress.remove_task(task_id=task)
+
+with open(args.output, "w") as f:
+    json.dump(results, f, sort_keys=True, indent=4)
diff --git a/benches/mapping/plot.py b/benches/mapping/plot.py
new file mode 100644
index 0000000..c8da6f4
--- /dev/null
+++ b/benches/mapping/plot.py
@@ -0,0 +1,89 @@
+"""Plot the results of the query mapping benchmark.
+
+The left panel shows the mean query time against the genome size for each
+thread count; the right panel shows the mean query time against the number
+of threads for each genome size, with a fitted exponential decay curve.
+"""
+
+import argparse
+import itertools
+import json
+import os
+import re
+import math
+
+import numpy
+import matplotlib.pyplot as plt
+import scipy.stats
+from scipy.optimize import curve_fit
+from numpy.polynomial import Polynomial
+from palettable.cartocolors.qualitative import Bold_9
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-i", "--input", required=True)
+parser.add_argument("-o", "--output")
+parser.add_argument("-s", "--show", action="store_true")
+args = parser.parse_args()
+
+def exp_decay(x, a, b, c):
+    """Evaluate the exponential decay ``a * exp(-b * x) + c`` at ``x``."""
+    return a*numpy.exp(-b*x) + c
+
+with open(args.input) as f:
+    data = json.load(f)
+
+plt.figure(1, figsize=(12, 6))
+
+# Left panel: query time against genome size, one series per thread count.
+plt.subplot(1, 2, 1)
+# groupby only merges consecutive items, so sort by the grouping key first
+data["results"].sort(key=lambda r: (r["threads"], r["nucleotides"]))
+for color, (threads, group) in zip(
+    itertools.cycle(Bold_9.hex_colors),
+    itertools.groupby(data["results"], key=lambda r: r["threads"])
+):
+    group = list(group)
+    X = numpy.array([r["nucleotides"] / 1e6 for r in group])
+    Y = numpy.array([r["mean"] for r in group])
+    plt.scatter(X, Y, marker="+", color=color, label=f"threads={threads}")
+
+plt.legend()
+plt.xlabel("Genome size (Mbp)")
+plt.ylabel("Query Time (s)")
+
+# Right panel: query time against thread count, one series per genome size.
+plt.subplot(1, 2, 2)
+data["results"].sort(key=lambda r: (r["nucleotides"], r["threads"]))
+for color, (nucleotides, group) in zip(
+    itertools.cycle(Bold_9.hex_colors),
+    itertools.groupby(data["results"], key=lambda r: r["nucleotides"])
+):
+    group = list(group)
+    X = numpy.array([r["threads"] for r in group])
+    Y = numpy.array([r["mean"] for r in group])
+    plt.scatter(X, Y, marker="+", color=color, label=f"{group[0]['genome']} ({nucleotides/1e6:.1f} Mbp)")
+
+    # The model has 3 parameters, so it cannot be fitted with fewer than 3
+    # points; the fit may also fail to converge. In both cases only the
+    # measurements are plotted.
+    if len(X) < 3:
+        continue
+    try:
+        popt, pcov = curve_fit(exp_decay, X, Y)
+    except RuntimeError:
+        continue
+    pX = numpy.linspace(1, max(r["threads"] for r in group), 100)
+    plt.plot(pX, exp_decay(pX, *popt), color=color, linestyle="--")
+
+plt.legend()
+plt.xlabel("Threads")
+plt.ylabel("Query Time (s)")
+
+plt.tight_layout()
+# Swap only the file extension: `str.replace` would leave a path without a
+# `.json` extension unchanged, and the figure would overwrite the input file.
+output = args.output or os.path.splitext(args.input)[0] + ".svg"
+plt.savefig(output, transparent=True)
+if args.show:
+    plt.show()
-- 
GitLab