Commit 172297d4 authored by Martin Larralde
Browse files

Update plot script in `connection_scoring` benchmark

parent 9666df6e
......@@ -4,6 +4,9 @@ import glob
import time
import statistics
import json
import sys
sys.path.append(os.path.realpath(os.path.join(__file__, "..", "..", "..")))
import tqdm
......@@ -45,7 +48,7 @@ for filename in tqdm.tqdm(glob.glob(os.path.join(args.data, "*.fna"))):
# run connection scoring
for backend in ["avx", "sse", "generic", None]:
times = []
for run in tqdm.tqdm(range(args.runs), description=str(backend), leave=False):
for run in tqdm.tqdm(range(args.runs), desc=str(backend), leave=False):
# initialize scorer
scorer = ConnectionScorer(backend=backend)
scorer_nodes = nodes.copy()
......
......@@ -8,7 +8,7 @@ import math
import numpy
import matplotlib.pyplot as plt
import scipy.stats
from palettable.cartocolors.qualitative import Bold_3
from palettable.cartocolors.qualitative import Bold_4
plt.rcParams["svg.fonttype"] = "none"
......@@ -25,6 +25,8 @@ with open(args.input) as f:
for result in data["results"]:
if result["backend"] is None:
result["backend"] = "None"
elif result["backend"] == "generic":
result["backend"] = "Generic"
else:
result["backend"] = result["backend"].upper()
......@@ -33,15 +35,15 @@ plt.figure(1, figsize=(12, 6))
plt.subplot(1, 2, 1)
data["results"].sort(key=lambda r: (r["backend"], r["node_count"]))
for color, (backend, group) in zip(
Bold_3.hex_colors, itertools.groupby(data["results"], key=lambda r: r["backend"])
Bold_4.hex_colors, itertools.groupby(data["results"], key=lambda r: r["backend"])
):
group = list(group)
X = numpy.array([r["node_count"] for r in group])
Y = numpy.array([r["mean"] for r in group])
reg = scipy.stats.linregress(X, Y)
plt.plot([ 0, max(X) ], [ reg.intercept, reg.slope*max(X) + reg.intercept ], color=color, linestyle="--", marker="")
ci = [1.96 * r["stddev"] / math.sqrt(len(r["times"])) for r in group]
plt.errorbar(X, Y, ci, linestyle='', marker="+", color=color, elinewidth=0.3, ecolor='black', label=f"{backend} (R={reg.rvalue:.3f})")
# ci = [1.96 * r["stddev"] / math.sqrt(len(r["times"])) for r in group]
plt.scatter(X, Y, marker="+", color=color, label=f"{backend} (R={reg.rvalue:.3f})")
plt.legend()
plt.xlabel("Node count")
plt.ylabel("Time (s)")
......@@ -50,15 +52,15 @@ plt.ylabel("Time (s)")
plt.subplot(1, 2, 2)
data["results"].sort(key=lambda r: (r["backend"], r["nucleotide_count"]))
for color, (backend, group) in zip(
Bold_3.hex_colors, itertools.groupby(data["results"], key=lambda r: r["backend"])
Bold_4.hex_colors, itertools.groupby(data["results"], key=lambda r: r["backend"])
):
group = list(group)
X = numpy.array([r["nucleotide_count"] / 1_000_000 for r in group])
Y = numpy.array([r["mean"] for r in group])
reg = scipy.stats.linregress(X, Y)
plt.plot([ 0, max(X) ], [ reg.intercept, reg.slope*max(X) + reg.intercept ], color=color, linestyle="--", marker="")
ci = [1.96 * r["stddev"] / math.sqrt(len(r["times"])) for r in group]
plt.errorbar(X, Y, ci, linestyle='', marker="+", color=color, elinewidth=0.3, ecolor='black', label=f"{backend} (R={reg.rvalue:.3f})")
# ci = [1.96 * r["stddev"] / math.sqrt(len(r["times"])) for r in group]
plt.scatter(X, Y, marker="+", color=color, label=f"{backend} (R={reg.rvalue:.3f})")
plt.legend()
plt.xlabel("Nucleotide count (Mbp)")
plt.ylabel("Time (s)")
......
Supports Markdown
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment