Skip to content
Snippets Groups Projects
Commit f3038fb1 authored by Martin Larralde's avatar Martin Larralde
Browse files

Add scripts to benchmark the effect of threads on the query mapping performance

parent 574ea624
No related branches found
No related tags found
No related merge requests found
*.fna
#!/bin/sh
import os
import gzip
import shutil
import urllib.request
import rich.progress
# proGenomes sample identifiers of the form "<NCBI taxon id>.<BioSample accession>";
# the taxon id prefix is split off below and both parts are passed to the
# dumpSequence.cgi download endpoint.
SAMPLES = [
    "76859.SAMN03263149",
    "76856.SAMN03263147",
    "562982.SAMN02595332",
    "154046.SAMEA3545258",
    "322095.SAMN03854412",
    "596327.SAMN00002220",
    "28131.SAMD00034934",
    "1123263.SAMN02441153",
    "33039.SAMEA3545241",
    "1512.SAMEA3545330",
    "926561.SAMN02261313",
    "1121279.SAMN02745887",
    "1216006.SAMEA4552037",
    "715226.SAMN02469919",
    "1121935.SAMN02440417",
    "1210082.SAMD00040656",
    "1189612.SAMN02469969",
    "518637.SAMN00008825",
    "585506.SAMN00139194",
    "82977.SAMN02928654",
    "1267533.SAMN02440487",
    "373.SAMN04223557",
    "298701.SAMN02471586",
    "1852381.SAMEA4555866",
    "1341181.SAMN02470908",
    "1194088.SAMN02470121",
    "198092.SAMN02745194",
    "1561204.SAMN03107784",
    "570952.SAMN02441250",
    "419481.SAMN05216233",
    "888845.SAMN05017598",
    "1254432.SAMN02603858",
    "1909395.SAMN05912833",
    "448385.SAMEA3138271",
    "1343740.SAMN02604192",
    "48.SAMN03704078",
    "1391654.SAMN03951129",
    "1912.SAMN05935554",
    "576784.SAMEA2519540",
    "749414.SAMN02603683",
    "722472.SAMN05444171",
    "1173028.SAMN02261253",
    "1134687.SAMN02463898",
    "398579.SAMN02598386",
    "54388.SAMEA2645827",
    "1773.SAMN06010452",
    "1281227.SAMN01885939",
    "645465.SAMN02595349",
    "36809.SAMN04572937",
    "649639.SAMN00007429",
]
# Download every sample's gzipped contig FASTA from proGenomes and store the
# decompressed sequences next to this script as "<sample>.fna".
data_folder = os.path.dirname(os.path.realpath(__file__))
for sample in rich.progress.track(SAMPLES, description="Downloading..."):
    # The taxon id is the part of the identifier before the first dot.
    tax_id, _, _ = sample.partition(".")
    url = "https://progenomes.embl.de/dumpSequence.cgi?p={}&t=c&a={}".format(sample, tax_id)
    destination = os.path.join(data_folder, "{}.fna".format(sample))
    # Decompress the HTTP response on the fly while copying it to disk.
    with urllib.request.urlopen(url) as res, gzip.open(res) as src, open(destination, "wb") as dst:
        shutil.copyfileobj(src, dst)
import argparse
import glob
import json
import os
import time
import statistics
import sys
import warnings
from itertools import islice
sys.path.insert(0, os.path.realpath(os.path.join(__file__, "..", "..")))
import Bio.SeqIO
import numpy
import pandas
import rich.progress
from pyfastani import Sketch
# Command-line interface of the benchmark runner:
#   -r/--runs    number of repeated timings per (genome, thread count) pair
#   -d/--data    directory containing the query genomes
#   -o/--output  path of the JSON report to write
#   -j/--jobs    maximum number of threads to benchmark
parser = argparse.ArgumentParser()
parser.add_argument("-r", "--runs", type=int, default=3)
parser.add_argument("-d", "--data", required=True)
parser.add_argument("-o", "--output", required=True)
parser.add_argument("-j", "--jobs", type=int, default=os.cpu_count() or 1)
args = parser.parse_args()
# Load every genome (one list of SeqRecord contigs per file) from the data
# directory given on the command line, plus the FastANI vendored test data.
# BUG FIX: the original hard-coded "data/*.fna" and silently ignored the
# required --data flag; the glob result is also sorted so that runs are
# deterministic regardless of filesystem ordering.
genomes = [
    list(Bio.SeqIO.parse(filename, "fasta"))
    for filename in sorted(
        glob.glob(os.path.join(args.data, "*.fna"))
        + glob.glob("vendor/FastANI/data/*.fna")
    )
]
# Sketch all genomes once, then time repeated queries of each genome against
# the index for every thread count from 1 to --jobs.
with rich.progress.Progress(transient=True) as progress:
    # Route warnings through the progress renderer so they do not break the bars.
    warnings.showwarning = lambda msg, c, f, l, file=None, line=None: progress.print(msg)
    sketch = Sketch()
    task = progress.add_task(total=len(genomes), description="Sketching...")
    for genome in progress.track(genomes, task_id=task):
        # Each genome is added as a draft: one id, many contig byte strings.
        sketch.add_draft(genome[0].id, [ bytes(contig.seq) for contig in genome ])
    progress.remove_task(task_id=task)
    # Freeze the sketch into a queryable index (pyfastani API).
    mapper = sketch.index()
    results = dict(results=[])
    task = progress.add_task(total=len(genomes), description="Querying...")
    for genome in progress.track(genomes, task_id=task):
        contigs = [ bytes(contig.seq) for contig in genome ]
        task2 = progress.add_task(total=args.jobs, description="Threads...")
        for thread_count in progress.track(range(1, args.jobs+1), task_id=task2):
            times = []
            task3 = progress.add_task(total=args.runs, description="Repeat...")
            for run in progress.track(range(args.runs), task_id=task3):
                # Wall-clock timing of a single query; only the query call is timed.
                t1 = time.time()
                hits = mapper.query_draft(contigs, threads=thread_count)
                t2 = time.time()
                times.append(t2 - t1)
            progress.remove_task(task3)
            # One record per (genome, thread count) with summary statistics.
            # NOTE(review): statistics.stdev raises if --runs is 1 — confirm
            # runs >= 2 is always used.
            results["results"].append({
                "genome": genome[0].id,
                "threads": thread_count,
                "nucleotides": sum(map(len, contigs)),
                "times": times,
                "mean": statistics.mean(times),
                "stddev": statistics.stdev(times),
                "median": statistics.median(times),
                "min": min(times),
                "max": max(times),
            })
        progress.remove_task(task_id=task2)
    progress.remove_task(task_id=task)
# Persist the collected timing records as pretty-printed, key-sorted JSON.
with open(args.output, "w") as handle:
    handle.write(json.dumps(results, sort_keys=True, indent=4))
import argparse
import itertools
import json
import os
import re
import math
import numpy
import matplotlib.pyplot as plt
import scipy.stats
from scipy.optimize import curve_fit
from numpy.polynomial import Polynomial
from palettable.cartocolors.qualitative import Bold_9
# Command-line interface of the plotting script:
#   -i/--input   JSON report produced by the benchmark runner
#   -o/--output  figure path (defaults to the input path with a .svg suffix)
#   -s/--show    display the figure interactively after saving
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", required=True)
parser.add_argument("-o", "--output")
parser.add_argument("-s", "--show", action="store_true")
args = parser.parse_args()
def exp_decay(x, a, b, c):
    """Exponential decay model ``a * exp(-b * x) + c`` used by curve_fit."""
    return c + a * numpy.exp(-b * x)
# Read the benchmark report written by the timing script.
with open(args.input) as handle:
    data = json.load(handle)
# Left panel: mean query time against genome size, one color per thread count.
# (Dead commented-out Polynomial/linregress experiments removed.)
plt.figure(1, figsize=(12, 6))
plt.subplot(1, 2, 1)
# itertools.groupby only groups consecutive items, so the records must be
# sorted by the grouping key (threads) first.
data["results"].sort(key=lambda r: (r["threads"], r["nucleotides"]))
for color, (threads, group) in zip(
    itertools.cycle(Bold_9.hex_colors),
    itertools.groupby(data["results"], key=lambda r: r["threads"])
):
    group = list(group)
    X = numpy.array([r["nucleotides"] / 1e6 for r in group])
    Y = numpy.array([r["mean"] for r in group])
    plt.scatter(X, Y, marker="+", color=color, label=f"threads={threads}")
plt.legend()
plt.xlabel("Genome size (Mbp)")
plt.ylabel("Query Time (s)")
# Right panel: mean query time against thread count, one color per genome,
# with an exponential-decay curve fitted through each series.
plt.subplot(1, 2, 2)
# Pre-sort by the grouping key (nucleotides) for itertools.groupby.
data["results"].sort(key=lambda r: (r["nucleotides"], r["threads"]))
for color, (nucleotides, grouped) in zip(
    itertools.cycle(Bold_9.hex_colors),
    itertools.groupby(data["results"], key=lambda r: r["nucleotides"])
):
    records = list(grouped)
    thread_counts = numpy.array([rec["threads"] for rec in records])
    mean_times = numpy.array([rec["mean"] for rec in records])
    # Fit the decay model to the measured means and sample it densely.
    popt, pcov = curve_fit(exp_decay, thread_counts, mean_times)
    fit_x = numpy.linspace(1, max(rec["threads"] for rec in records), 100)
    series_label = f"{records[0]['genome']} ({nucleotides/1e6:.1f} Mbp)"
    plt.scatter(thread_counts, mean_times, marker="+", color=color, label=series_label)
    plt.plot(fit_x, exp_decay(fit_x, *popt), color=color, linestyle="--")
plt.legend()
plt.xlabel("Threads")
plt.ylabel("Query Time (s)")
# plt.subplot(1, 2, 2)
# data["results"].sort(key=lambda r: (r["backend"], r["residues"]))
# for color, (backend, group) in zip(
# Bold_4.hex_colors, itertools.groupby(data["results"], key=lambda r: r["backend"])
# ):
# group = list(group)
# X = numpy.array([r["residues"] for r in group])
# Y = numpy.array([r["mean"] for r in group])
# p = Polynomial.fit(X, Y, 2)
# # reg = scipy.stats.linregress(X, Y)
# # plt.plot([ 0, max(X) ], [ reg.intercept, reg.slope*max(X) + reg.intercept ], color=color, linestyle="--", marker="")
# # ci = [1.96 * r["stddev"] / math.sqrt(len(r["times"])) for r in group]
# plt.scatter(X, Y, marker="+", color=color, label=f"{backend}")
# plt.plot(X, p(X), color=color, linestyle="--")
#
# plt.legend()
# plt.xlabel("Number of residues")
# plt.ylabel("Time (s)")
plt.tight_layout()
# BUG FIX: str.replace(".json", ".svg") would substitute *every* ".json"
# occurrence in the path (e.g. "a.json.d/r.json"); splitext only swaps the
# final extension.
output = args.output or os.path.splitext(args.input)[0] + ".svg"
plt.savefig(output, transparent=True)
if args.show:
    plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment