# test_connection_scorer.py
1
2
3
4
5
6
7
import collections.abc
import gzip
import os
import sys
import unittest

from .. import Nodes, Sequence, _pyrodigal
8
from .._pyrodigal import METAGENOMIC_BINS, ConnectionScorer
9
10
11
12
13
14

from .fasta import parse


class TestConnectionScorer(unittest.TestCase):

15
16
17
18
19
20
21
    def assertNodeEqual(self, n1, n2):
        """Assert that two nodes carry the same attributes and scores.

        The ``index``, ``strand``, ``type``, ``edge``, ``gc_bias`` and
        ``gc_cont`` attributes are compared with `assertEqual`; the
        ``score``, ``cscore``, ``rscore``, ``sscore`` and ``tscore``
        attributes are compared with `assertAlmostEqual`.
        """
        exact_fields = ("index", "strand", "type", "edge", "gc_bias", "gc_cont")
        approx_fields = ("score", "cscore", "rscore", "sscore", "tscore")
        for field in exact_fields:
            self.assertEqual(getattr(n1, field), getattr(n2, field))
        for field in approx_fields:
            self.assertAlmostEqual(getattr(n1, field), getattr(n2, field))
27

28
29
30
31
32
33
34
35
    @classmethod
    def setUpClass(cls):
        """Read the first record of the gzipped FASTA fixture into ``cls.record``.

        The fixture is looked up in the ``data`` folder located next to this
        test module.
        """
        folder = os.path.realpath(os.path.join(__file__, "..", "data"))
        with gzip.open(os.path.join(folder, "MIIJ01000039.fna.gz"), "rt") as handle:
            records = parse(handle)
            cls.record = next(records)

    @unittest.skipUnless(_pyrodigal._TARGET_CPU == "x86", "requires x86 CPU")
    @unittest.skipUnless(_pyrodigal._SSE2_BUILD_SUPPORT, "requires extension compiled with SSE2 support")
    @unittest.skipUnless(_pyrodigal._SSE2_RUNTIME_SUPPORT, "requires machine with SSE2 support")
    def test_score_connections_sse(self):
        """Check that the ``sse`` backend scores nodes like ``backend=None``.

        Two scorers index the same sorted nodes, each one scores its own copy
        of them, and the resulting nodes are compared pairwise.
        """
        window = 500
        sequence = Sequence.from_string(self.record.seq)
        tinf = METAGENOMIC_BINS[0].training_info
        sse_scorer = ConnectionScorer(backend="sse")
        plain_scorer = ConnectionScorer(backend=None)

        # extract the nodes of the sequence and sort them
        nodes = Nodes()
        nodes.extract(sequence, translation_table=tinf.translation_table)
        nodes.sort()

        # both scorers index the same nodes...
        sse_scorer.index(nodes)
        plain_scorer.index(nodes)
        # ... but each one writes its scores to a private copy
        sse_nodes = nodes.copy()
        plain_nodes = nodes.copy()

        # the reference scorer always runs first, then the SSE one
        runs = ((plain_scorer, plain_nodes), (sse_scorer, sse_nodes))
        for upper in range(window, len(nodes)):
            # lower boundary trails the current node by `window`, never below 0
            lower = max(0, upper - window)
            for scorer, scored in runs:
                scorer.compute_skippable(lower, upper)
                scorer.score_connections(scored, lower, upper, tinf, final=True)

        # both backends must have produced the same nodes
        for sse_node, plain_node in zip(sse_nodes, plain_nodes):
            self.assertNodeEqual(sse_node, plain_node)
66
67

    @unittest.skipUnless(_pyrodigal._TARGET_CPU == "x86", "requires x86 CPU")
    @unittest.skipUnless(_pyrodigal._AVX2_BUILD_SUPPORT, "requires extension compiled with AVX2 support")
    @unittest.skipUnless(_pyrodigal._AVX2_RUNTIME_SUPPORT, "requires machine with AVX2 support")
    def test_score_connections_avx(self):
        """Check that the ``avx`` backend scores nodes like ``backend=None``.

        The same sorted nodes are indexed by both scorers; each scorer then
        scores a separate copy, and the copies are compared node by node.
        """
        span = 500
        seq = Sequence.from_string(self.record.seq)
        training_info = METAGENOMIC_BINS[0].training_info
        avx = ConnectionScorer(backend="avx")
        reference = ConnectionScorer(backend=None)

        # collect the nodes of the sequence, in sorted order
        nodes = Nodes()
        nodes.extract(seq, translation_table=training_info.translation_table)
        nodes.sort()

        # index the shared nodes, then score independent copies
        avx.index(nodes)
        reference.index(nodes)
        scored_avx = nodes.copy()
        scored_reference = nodes.copy()

        for stop in range(span, len(nodes)):
            # start boundary sits `span` nodes before `stop`, clamped at 0
            start = max(0, stop - span)
            # reference scoring with the backend-less scorer
            reference.compute_skippable(start, stop)
            reference.score_connections(
                scored_reference, start, stop, training_info, final=True
            )
            # same window scored with the AVX backend
            avx.compute_skippable(start, stop)
            avx.score_connections(
                scored_avx, start, stop, training_info, final=True
            )

        # the two scorers must agree on every node
        for got, expected in zip(scored_avx, scored_reference):
            self.assertNodeEqual(got, expected)
98
99
100
101
102
103
104
105
106

    @unittest.skipUnless(_pyrodigal._TARGET_CPU in ("arm", "aarch64"), "requires ARM CPU")
    @unittest.skipUnless(_pyrodigal._NEON_BUILD_SUPPORT, "requires extension compiled with NEON support")
    @unittest.skipUnless(_pyrodigal._NEON_RUNTIME_SUPPORT, "requires machine with NEON support")
    def test_score_connections_neon(self):
        """Check that the ``neon`` backend scores nodes like ``backend=None``.

        Both scorers index the same sorted nodes and score their own copy;
        the two copies are then compared with `assertNodeEqual`.
        """
        # setup
        seq = Sequence.from_string(self.record.seq)
        tinf = METAGENOMIC_BINS[0].training_info
        scorer_neon = ConnectionScorer(backend="neon")
        scorer_none = ConnectionScorer(backend=None)
        # add nodes from the sequence
        nodes = Nodes()
        nodes.extract(seq, translation_table=tinf.translation_table)
        nodes.sort()
        # index nodes for the scorers
        scorer_neon.index(nodes)
        scorer_none.index(nodes)
        # use copies to compute both scores
        nodes_neon = nodes.copy()
        nodes_none = nodes.copy()
        for i in range(500, len(nodes)):
            # compute boundary: 500 nodes before `i`, clamped at zero
            j = max(0, i - 500)
            # score connections without fast-indexing skippable nodes
            scorer_none.compute_skippable(j, i)
            scorer_none.score_connections(nodes_none, j, i, tinf, final=True)
            # compute skippable nodes with NEON and score connections
            scorer_neon.compute_skippable(j, i)
            scorer_neon.score_connections(nodes_neon, j, i, tinf, final=True)
        # check that both methods scored the same
        for n1, n2 in zip(nodes_neon, nodes_none):
            self.assertNodeEqual(n1, n2)

    def test_score_connections_generic(self):
        """Check that the ``generic`` backend scores nodes like ``backend=None``.

        Each scorer indexes the same sorted nodes and scores a dedicated copy;
        the copies are finally compared pairwise.
        """
        distance = 500
        sequence = Sequence.from_string(self.record.seq)
        tinf = METAGENOMIC_BINS[0].training_info
        generic = ConnectionScorer(backend="generic")
        baseline = ConnectionScorer(backend=None)

        # build the sorted node list from the sequence
        nodes = Nodes()
        nodes.extract(sequence, translation_table=tinf.translation_table)
        nodes.sort()

        # give both scorers the same index
        for scorer in (generic, baseline):
            scorer.index(nodes)

        # score separate copies so the results can be compared afterwards
        generic_nodes = nodes.copy()
        baseline_nodes = nodes.copy()
        for last in range(distance, len(nodes)):
            # window start: `distance` nodes before `last`, never negative
            first = max(0, last - distance)
            # baseline pass, without fast-indexing skippable nodes
            baseline.compute_skippable(first, last)
            baseline.score_connections(baseline_nodes, first, last, tinf, final=True)
            # same window with the generic filter
            generic.compute_skippable(first, last)
            generic.score_connections(generic_nodes, first, last, tinf, final=True)

        # every node must have been scored identically
        for generic_node, baseline_node in zip(generic_nodes, baseline_nodes):
            self.assertNodeEqual(generic_node, baseline_node)