# plot.py: plot benchmark results (mean time per backend against node count
# and against nucleotide count) from a JSON results file.
import argparse
import itertools
import json
import math
import os
import re

import matplotlib.pyplot as plt
import numpy
import scipy.stats
from palettable.cartocolors.qualitative import Bold_4
# Keep text in the SVG as real text elements instead of converting glyphs to
# paths, so labels stay editable/searchable in the exported figure.
plt.rcParams["svg.fonttype"] = "none"


parser = argparse.ArgumentParser(
    description="Plot benchmark mean times per backend against node count "
    "and nucleotide count."
)
parser.add_argument("-i", "--input", required=True, help="JSON results file to plot.")
parser.add_argument(
    "-o", "--output", help="Output image path (default: derived from --input)."
)
parser.add_argument(
    "-s", "--show", action="store_true", help="Display the figure after saving it."
)
args = parser.parse_args()


# Load the benchmark results and normalise each backend name for display.
# A missing backend (JSON null -> None) becomes the string "None" so it can be
# sorted alongside the other (string) names and used as a legend label;
# "generic" becomes "Generic"; every other name is upper-cased.
with open(args.input, encoding="utf-8") as f:
    data = json.load(f)
for result in data["results"]:
    backend = result["backend"]
    if backend is None:
        result["backend"] = "None"
    elif backend == "generic":
        result["backend"] = "Generic"
    else:
        result["backend"] = backend.upper()

def _plot_panel(results, x_key, x_scale=1):
    """Plot mean time against ``x_key`` for every backend in the current axes.

    Each backend gets a "+" scatter of its results and, when its x values are
    not all identical, a dashed least-squares line drawn from x=0 to its
    largest x value; the legend label then carries the fit's R value.

    Args:
        results: Result dicts with at least "backend", "mean" and ``x_key``.
            The list is not modified.
        x_key: Result field used as the x coordinate.
        x_scale: Divisor applied to the x values (e.g. 1_000_000 for Mbp).
    """
    ordered = sorted(results, key=lambda r: (r["backend"], r[x_key]))
    # Colours are assigned in sorted backend order, so panels plotting the same
    # backends colour them consistently. cycle() reuses the palette instead of
    # zip() silently dropping every backend past the end of the palette.
    colors = itertools.cycle(Bold_4.hex_colors)
    by_backend = itertools.groupby(ordered, key=lambda r: r["backend"])
    for color, (backend, group) in zip(colors, by_backend):
        group = list(group)
        X = numpy.array([r[x_key] for r in group]) / x_scale
        Y = numpy.array([r["mean"] for r in group])
        label = f"{backend} (no fit)"
        # linregress raises ValueError when all x values are identical (e.g. a
        # backend with a single result), so only fit when x has a spread.
        if numpy.unique(X).size > 1:
            reg = scipy.stats.linregress(X, Y)
            x_end = X.max()
            # Start the line at x=0 so the intercept is visible on the plot.
            plt.plot(
                [0, x_end],
                [reg.intercept, reg.slope * x_end + reg.intercept],
                color=color,
                linestyle="--",
                marker="",
            )
            label = f"{backend} (R={reg.rvalue:.3f})"
        # TODO: optionally draw 95% CI error bars per result:
        # 1.96 * r["stddev"] / math.sqrt(len(r["times"])).
        plt.scatter(X, Y, marker="+", color=color, label=label)


plt.figure(1, figsize=(12, 6))

plt.subplot(1, 2, 1)
_plot_panel(data["results"], "node_count")
plt.legend()
plt.xlabel("Node count")
plt.ylabel("Time (s)")

plt.subplot(1, 2, 2)
_plot_panel(data["results"], "nucleotide_count", x_scale=1_000_000)
plt.legend()
plt.xlabel("Nucleotide count (Mbp)")
plt.ylabel("Time (s)")


# Default output path: the input path with its extension swapped for ".svg".
# (A plain str.replace(".json", ".svg") leaves a non-.json input path
# unchanged, which would make savefig overwrite the input data file.)
output = args.output or os.path.splitext(args.input)[0] + ".svg"
plt.savefig(output, transparent=True)
if args.show:
    plt.show()