From f3038fb1a32d972bf46d9286a37633a818e34b51 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Wed, 3 Aug 2022 14:58:00 +0200
Subject: [PATCH] Add scripts to benchmark the effect of threads on the query
 mapping performance

---
 benches/data/.gitignore  |  1 +
 benches/data/download.py | 70 +++++++++++++++++++++++++++++
 benches/mapping/bench.py | 71 ++++++++++++++++++++++++++++++
 benches/mapping/plot.py  | 95 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 237 insertions(+)
 create mode 100644 benches/data/.gitignore
 create mode 100644 benches/data/download.py
 create mode 100644 benches/mapping/bench.py
 create mode 100644 benches/mapping/plot.py

diff --git a/benches/data/.gitignore b/benches/data/.gitignore
new file mode 100644
index 0000000..345cc57
--- /dev/null
+++ b/benches/data/.gitignore
@@ -0,0 +1 @@
+*.fna
diff --git a/benches/data/download.py b/benches/data/download.py
new file mode 100644
index 0000000..1a207bd
--- /dev/null
+++ b/benches/data/download.py
@@ -0,0 +1,74 @@
+#!/usr/bin/env python3
+"""Download reference genomes from proGenomes for benchmarking."""
+
+import os
+import gzip
+import shutil
+import urllib.request
+
+import rich.progress
+
+# proGenomes sample identifiers, formatted as "<taxid>.<biosample>"
+SAMPLES = [
+    "76859.SAMN03263149",
+    "76856.SAMN03263147",
+    "562982.SAMN02595332",
+    "154046.SAMEA3545258",
+    "322095.SAMN03854412",
+    "596327.SAMN00002220",
+    "28131.SAMD00034934",
+    "1123263.SAMN02441153",
+    "33039.SAMEA3545241",
+    "1512.SAMEA3545330",
+    "926561.SAMN02261313",
+    "1121279.SAMN02745887",
+    "1216006.SAMEA4552037",
+    "715226.SAMN02469919",
+    "1121935.SAMN02440417",
+    "1210082.SAMD00040656",
+    "1189612.SAMN02469969",
+    "518637.SAMN00008825",
+    "585506.SAMN00139194",
+    "82977.SAMN02928654",
+    "1267533.SAMN02440487",
+    "373.SAMN04223557",
+    "298701.SAMN02471586",
+    "1852381.SAMEA4555866",
+    "1341181.SAMN02470908",
+    "1194088.SAMN02470121",
+    "198092.SAMN02745194",
+    "1561204.SAMN03107784",
+    "570952.SAMN02441250",
+    "419481.SAMN05216233",
+    "888845.SAMN05017598",
+    "1254432.SAMN02603858",
+    "1909395.SAMN05912833",
+    "448385.SAMEA3138271",
+    "1343740.SAMN02604192",
+    "48.SAMN03704078",
+    "1391654.SAMN03951129",
+    "1912.SAMN05935554",
+    "576784.SAMEA2519540",
+    "749414.SAMN02603683",
+    "722472.SAMN05444171",
+    "1173028.SAMN02261253",
+    "1134687.SAMN02463898",
+    "398579.SAMN02598386",
+    "54388.SAMEA2645827",
+    "1773.SAMN06010452",
+    "1281227.SAMN01885939",
+    "645465.SAMN02595349",
+    "36809.SAMN04572937",
+    "649639.SAMN00007429",
+]
+
+# store the downloaded sequences next to this script, matching the .gitignore
+data_folder = os.path.dirname(os.path.realpath(__file__))
+for sample in rich.progress.track(SAMPLES, description="Downloading..."):
+    tax_id = sample.split(".")[0]
+    url = "https://progenomes.embl.de/dumpSequence.cgi?p={}&t=c&a={}".format(sample, tax_id)
+    # the server replies with a gzip-compressed FASTA file: decompress
+    # on the fly while streaming it to "<sample>.fna" on disk
+    with urllib.request.urlopen(url) as res:
+        with gzip.open(res) as src:
+            with open(os.path.join(data_folder, "{}.fna".format(sample)), "wb") as dst:
+                shutil.copyfileobj(src, dst)
diff --git a/benches/mapping/bench.py b/benches/mapping/bench.py
new file mode 100644
index 0000000..a9995ad
--- /dev/null
+++ b/benches/mapping/bench.py
@@ -0,0 +1,76 @@
+"""Benchmark the effect of thread count on `pyfastani` query mapping."""
+import argparse
+import glob
+import json
+import os
+import time
+import statistics
+import sys
+import warnings
+from itertools import islice
+
+# make the local build of the package importable before any installed copy
+sys.path.insert(0, os.path.realpath(os.path.join(__file__, "..", "..")))
+
+import Bio.SeqIO
+import numpy
+import pandas
+import rich.progress
+from pyfastani import Sketch
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-r", "--runs", default=3, type=int, help="number of timed repetitions per measurement")
+parser.add_argument("-d", "--data", required=True, help="folder containing the *.fna genomes to benchmark with")
+parser.add_argument("-o", "--output", required=True, help="path to the JSON file where to write the results")
+parser.add_argument("-j", "--jobs", default=os.cpu_count() or 1, type=int, help="maximum number of threads to benchmark")
+args = parser.parse_args()
+
+# load every genome of the data folder as a list of contig records
+genomes = [
+    list(Bio.SeqIO.parse(filename, "fasta"))
+    for filename in glob.glob(os.path.join(args.data, "*.fna"))
+]
+
+with rich.progress.Progress(transient=True) as progress:
+    # route warnings through the progress bar so they don't garble the display
+    warnings.showwarning = lambda msg, c, f, l, file=None, line=None: progress.print(msg)
+
+    sketch = Sketch()
+    task = progress.add_task(total=len(genomes), description="Sketching...")
+    for genome in progress.track(genomes, task_id=task):
+        sketch.add_draft(genome[0].id, [ bytes(contig.seq) for contig in genome ])
+    progress.remove_task(task_id=task)
+
+    mapper = sketch.index()
+
+    # time each genome query at every thread count, `args.runs` times each
+    results = dict(results=[])
+    task = progress.add_task(total=len(genomes), description="Querying...")
+    for genome in progress.track(genomes, task_id=task):
+        contigs = [ bytes(contig.seq) for contig in genome ]
+        task2 = progress.add_task(total=args.jobs, description="Threads...")
+        for thread_count in progress.track(range(1, args.jobs+1), task_id=task2):
+            times = []
+            task3 = progress.add_task(total=args.runs, description="Repeat...")
+            for run in progress.track(range(args.runs), task_id=task3):
+                t1 = time.time()
+                hits = mapper.query_draft(contigs, threads=thread_count)
+                t2 = time.time()
+                times.append(t2 - t1)
+            progress.remove_task(task3)
+            results["results"].append({
+                "genome": genome[0].id,
+                "threads": thread_count,
+                "nucleotides": sum(map(len, contigs)),
+                "times": times,
+                "mean": statistics.mean(times),
+                # stdev needs at least two samples; report 0.0 for --runs 1
+                "stddev": statistics.stdev(times) if len(times) > 1 else 0.0,
+                "median": statistics.median(times),
+                "min": min(times),
+                "max": max(times),
+            })
+        progress.remove_task(task_id=task2)
+    progress.remove_task(task_id=task)
+
+with open(args.output, "w") as f:
+    json.dump(results, f, sort_keys=True, indent=4)
diff --git a/benches/mapping/plot.py b/benches/mapping/plot.py
new file mode 100644
index 0000000..c8da6f4
--- /dev/null
+++ b/benches/mapping/plot.py
@@ -0,0 +1,75 @@
+"""Plot the query-mapping benchmark results produced by ``bench.py``."""
+import argparse
+import itertools
+import json
+import os
+import re
+import math
+
+import numpy
+import matplotlib.pyplot as plt
+import scipy.stats
+from scipy.optimize import curve_fit
+from numpy.polynomial import Polynomial
+from palettable.cartocolors.qualitative import Bold_9
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-i", "--input", required=True, help="JSON results file written by bench.py")
+parser.add_argument("-o", "--output", help="where to save the figure (default: input name with .svg suffix)")
+parser.add_argument("-s", "--show", action="store_true", help="display the figure interactively")
+args = parser.parse_args()
+
+def exp_decay(x, a, b, c):
+    """Exponential-decay model fitted to query time against thread count."""
+    return a*numpy.exp(-b*x) + c
+
+with open(args.input) as f:
+    data = json.load(f)
+
+plt.figure(1, figsize=(12, 6))
+
+# left panel: query time against genome size, one color per thread count
+plt.subplot(1, 2, 1)
+data["results"].sort(key=lambda r: (r["threads"], r["nucleotides"]))
+for color, (threads, group) in zip(
+    itertools.cycle(Bold_9.hex_colors),
+    itertools.groupby(data["results"], key=lambda r: r["threads"])
+):
+    group = list(group)
+    X = numpy.array([r["nucleotides"] / 1e6 for r in group])
+    Y = numpy.array([r["mean"] for r in group])
+    plt.scatter(X, Y, marker="+", color=color, label=f"threads={threads}")
+
+plt.legend()
+plt.xlabel("Genome size (Mbp)")
+plt.ylabel("Query Time (s)")
+
+# right panel: query time against thread count, one color per genome,
+# with an exponential-decay curve fitted to each genome's measurements
+plt.subplot(1, 2, 2)
+data["results"].sort(key=lambda r: (r["nucleotides"], r["threads"]))
+for color, (nucleotides, group) in zip(
+    itertools.cycle(Bold_9.hex_colors),
+    itertools.groupby(data["results"], key=lambda r: r["nucleotides"])
+):
+    group = list(group)
+    X = numpy.array([r["threads"] for r in group])
+    Y = numpy.array([r["mean"] for r in group])
+
+    popt, pcov = curve_fit(exp_decay, X, Y)
+    pX = numpy.linspace(1, max(r["threads"] for r in group), 100)
+
+    plt.scatter(X, Y, marker="+", color=color, label=f"{group[0]['genome']} ({nucleotides/1e6:.1f} Mbp)")
+    plt.plot(pX, exp_decay(pX, *popt), color=color, linestyle="--")
+
+plt.legend()
+plt.xlabel("Threads")
+plt.ylabel("Query Time (s)")
+
+plt.tight_layout()
+# derive the output name from the input name unless one was given explicitly
+output = args.output or "{}.svg".format(os.path.splitext(args.input)[0])
+plt.savefig(output, transparent=True)
+if args.show:
+    plt.show()
-- 
GitLab