_pyrodigal.pyx 174 KB
Newer Older
1
# coding: utf-8
2
# cython: language_level=3, linetrace=True
3

4
"""Bindings to Prodigal, an ORF finder for genomes and metagenomes.
5
6

Example:
    Pyrodigal can work on any DNA sequence stored in either a text or a
    byte array. To load a sequence from one of the common sequence formats,
    you can use an external dedicated library such as
    `Biopython <https://github.com/biopython/biopython>`_::

        >>> import gzip
        >>> import Bio.SeqIO
        >>> with gzip.open("pyrodigal/tests/data/KK037166.fna.gz", "rt") as f:
        ...     record = Bio.SeqIO.read(f, "fasta")

    Then use Pyrodigal to find the genes in *metagenomic* mode (without
    training first), and then build a map of codon frequencies for each
    gene::

        >>> from collections import Counter
        >>> import pyrodigal
        >>> p = pyrodigal.OrfFinder(meta=True)
        >>> for gene in p.find_genes(record.seq.encode()):
        ...     gene_seq = gene.sequence()
        ...     codon_counter = Counter()
        ...     for i in range(0, len(gene_seq), 3):
        ...         codon_counter[gene_seq[i:i+3]] += 1
        ...     codon_frequencies = {
        ...         codon:count/(len(gene_seq)//3)
        ...         for codon, count in codon_counter.items()
        ...     }

Caution:
    In Pyrodigal, sequences are assumed to contain only the usual
    nucleotides (A/T/G/C) as lowercase or uppercase letters; any other
    symbol will be treated as an unknown nucleotide. Be careful to remove
    the gap characters if loading sequences from a multiple alignment file.

References:
    - Hyatt D, Chen GL, Locascio PF, Land ML, Larimer FW, Hauser LJ.
      *Prodigal: prokaryotic gene recognition and translation initiation
      site identification.* BMC Bioinformatics. 2010 Mar 8;11:119.
      doi:10.1186/1471-2105-11-119. PMID:20211023. PMCID:PMC2848648.
45

46
"""
47

48
49
# ----------------------------------------------------------------------------

50
from cpython.buffer cimport PyBUF_READ, PyBUF_WRITE
51
from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AsString
52
53
from cpython.exc cimport PyErr_CheckSignals
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
54
from cpython.memoryview cimport PyMemoryView_FromMemory
55
56
from cpython.ref cimport Py_INCREF
from cpython.tuple cimport PyTuple_New, PyTuple_SET_ITEM
57
from libc.math cimport sqrt, log, pow, fmax, fmin
58
from libc.stdint cimport int8_t, uint8_t, uintptr_t
59
from libc.stdio cimport printf
60
from libc.stdlib cimport abs, malloc, calloc, free, qsort
61
from libc.string cimport memcpy, memchr, memset, strstr
62
63
64

from pyrodigal.prodigal cimport bitmap, dprog, gene, node, sequence
from pyrodigal.prodigal.metagenomic cimport NUM_META, _metagenomic_bin, initialize_metagenomic_bins
65
from pyrodigal.prodigal.node cimport _motif, _node, MIN_EDGE_GENE, MIN_GENE, MAX_SAM_OVLP, cross_mask, compare_nodes, stopcmp_nodes
66
from pyrodigal.prodigal.sequence cimport _mask, node_type, rcom_seq
67
68
from pyrodigal.prodigal.training cimport _training
from pyrodigal._unicode cimport *
69
70
71
72
73
74
75
76
77
78
from pyrodigal._sequence cimport (
    nucleotide,
    _is_a,
    _is_g,
    _is_gc,
    _is_stop,
    _is_start,
    _is_atg,
    _is_ttg,
    _is_gtg,
79
    _mer_ndx,
80
81
82
    _letters,
    _complement
)
83
from pyrodigal.impl.generic cimport skippable_generic
84

85
86
87
88
89
90
IF TARGET_CPU == "x86":
    from pyrodigal.cpu_features.x86 cimport GetX86Info, X86Info
    IF SSE2_BUILD_SUPPORT:
        from pyrodigal.impl.sse cimport skippable_sse
    IF AVX2_BUILD_SUPPORT:
        from pyrodigal.impl.avx cimport skippable_avx
91
ELIF TARGET_CPU == "arm" or TARGET_CPU == "aarch64":
92
93
    IF TARGET_CPU == "arm":
        from pyrodigal.cpu_features.arm cimport GetArmInfo, ArmInfo
94
95
    IF NEON_BUILD_SUPPORT:
        from pyrodigal.impl.neon cimport skippable_neon
96

97
# Flags given to `PyMemoryView_FromMemory` when exposing internal buffers.
# On PyPy both constants request read *and* write access; on other
# implementations each constant maps to the matching CPython flag.
IF SYS_IMPLEMENTATION_NAME == "pypy":
    cdef int MVIEW_READ  = PyBUF_READ | PyBUF_WRITE
    cdef int MVIEW_WRITE = PyBUF_READ | PyBUF_WRITE
ELSE:
    cdef int MVIEW_READ  = PyBUF_READ
    cdef int MVIEW_WRITE = PyBUF_WRITE
103

104
105
106
# ----------------------------------------------------------------------------

import warnings
107
import textwrap
108
109
import threading

110
include "_version.py"
111

112
113
# --- Module-level constants -------------------------------------------------

114
115
cdef int    IDEAL_SINGLE_GENOME = 100000
cdef int    MIN_SINGLE_GENOME   = 20000
cdef int    WINDOW              = 120
# initial capacity (in elements) of the growable C arrays of this module
cdef size_t MIN_MASKS_ALLOC     = 8
cdef size_t MIN_GENES_ALLOC     = 8
cdef size_t MIN_NODES_ALLOC     = 8 * MIN_GENES_ALLOC
# identifiers of the translation tables accepted by this module
cdef set    TRANSLATION_TABLES  = set(range(1, 7)) | set(range(9, 17)) | set(range(21, 26))

# `cdef` globals are not visible from Python: expose the set under a
# module-level Python name as well
_TRANSLATION_TABLES = TRANSLATION_TABLES

124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221

# --- Sequence mask ----------------------------------------------------------

cdef class Mask:
    """The coordinates of a masked region.

    Instances are thin views: both properties read the ``_mask`` record
    referenced by the ``mask`` pointer of the object.
    """

    @property
    def begin(self):
        """The ``begin`` coordinate of the masked region.
        """
        return self.mask.begin

    @property
    def end(self):
        """The ``end`` coordinate of the masked region.
        """
        return self.mask.end

cdef class Masks:
    """A list of masked regions within a `~pyrodigal.Sequence`.

    The regions are stored in a contiguous C array of ``_mask`` records
    (``self.masks``), of which ``self.length`` entries are in use out of
    the ``self.capacity`` entries that are allocated.
    """

    def __cinit__(self):
        self.masks = NULL
        self.capacity = 0
        self.length = 0

    def __init__(self):
        self._clear()

    def __dealloc__(self):
        PyMem_Free(self.masks)

    def __copy__(self):
        return self.copy()

    def __len__(self):
        return self.length

    def __getitem__(self, ssize_t index):
        cdef Mask mask
        # support negative indices, like Python lists do
        if index < 0:
            index += <ssize_t> self.length
        if index >= <ssize_t> self.length or index < 0:
            raise IndexError("list index out of range")
        # The returned object does not copy the record: it points inside our
        # array and keeps `self` alive through `owner`.
        # NOTE(review): `_add_mask` may reallocate (and move) the array, which
        # would leave previously returned `Mask` objects with a stale pointer.
        mask = Mask.__new__(Mask)
        mask.owner = self
        mask.mask = &self.masks[index]
        return mask

    def __sizeof__(self):
        return self.capacity * sizeof(_mask) + sizeof(self)

    cdef inline _mask* _add_mask(
        self,
        const int  begin,
        const int  end,
    ) nogil except NULL:
        """Add a single mask to the vector, and return a pointer to that mask.

        The underlying array is doubled (starting from `MIN_MASKS_ALLOC`
        elements) whenever it is full. A `MemoryError` is raised when the
        array cannot be grown, in which case the vector is left untouched.
        """

        cdef size_t old_capacity = self.capacity
        cdef size_t new_capacity
        cdef _mask* new_masks
        cdef _mask* mask

        if self.length >= self.capacity:
            new_capacity = MIN_MASKS_ALLOC if old_capacity == 0 else old_capacity*2
            with gil:
                # keep the result in a temporary: on failure `PyMem_Realloc`
                # returns NULL but the original buffer remains valid, and
                # must stay reachable so that `__dealloc__` can release it
                new_masks = <_mask*> PyMem_Realloc(self.masks, new_capacity * sizeof(_mask))
                if new_masks == NULL:
                    raise MemoryError("Failed to reallocate mask array")
            self.masks = new_masks
            self.capacity = new_capacity
            # zero the newly allocated tail of the array
            memset(&self.masks[old_capacity], 0, (new_capacity - old_capacity) * sizeof(_mask))

        self.length += 1
        mask = &self.masks[self.length - 1]
        mask.begin = begin
        mask.end = end
        return mask

    cdef int _clear(self) nogil except 1:
        """Remove all masks from the vector.

        The allocated capacity is kept; only the used entries are zeroed.
        """
        cdef size_t old_length = self.length
        self.length = 0
        # nothing to wipe until the first allocation (`masks` is still NULL)
        if self.masks != NULL:
            memset(self.masks, 0, old_length * sizeof(_mask))
        return 0

    def clear(self):
        """Remove all masks from the vector.
        """
        with nogil:
            self._clear()

    cpdef Masks copy(self):
        """Return a new `Masks` object holding a copy of the mask array.
        """
        cdef Masks new = Masks.__new__(Masks)
        # an empty vector has no buffer to duplicate: `__cinit__` already
        # left `new` in the proper empty state
        if self.capacity > 0:
            new.masks = <_mask*> PyMem_Malloc(self.capacity * sizeof(_mask))
            if new.masks == NULL:
                raise MemoryError("Failed to allocate masks array")
            memcpy(new.masks, self.masks, self.capacity * sizeof(_mask))
            new.capacity = self.capacity
            new.length = self.length
        return new

222

223
224
225
# --- Input sequence ---------------------------------------------------------

cdef class Sequence:
226
    """A digitized input sequence.
227
228
229

    Attributes:
        gc (`float`): The GC content of the sequence, as a fraction.
230
231
        masks (`~pyrodigal.Masks`): A list of masked regions within the
            sequence.
232

233
234
    """

235
    # --- Class methods ------------------------------------------------------
236
237

    @classmethod
    def from_bytes(cls, const unsigned char[:] sequence, bint mask = False):
        """from_bytes(cls, sequence, mask=False)\n--

        Create a new `Sequence` object from an ASCII-encoded sequence.

        Arguments:
            sequence (`bytes`): The ASCII-encoded sequence to use. Any
                object implementing the *buffer protocol* is supported.
            mask (`bool`): Enable region-masking for spans of unknown
                characters, preventing genes from being built across them.

        Returns:
            `Sequence`: The digitized sequence, with its GC content
            computed and, if requested, its masked regions recorded.

        """
        cdef int           i
        cdef unsigned char letter
        cdef Sequence      seq
        cdef int           gc_count   = 0
        cdef int           mask_begin = -1

        seq = Sequence.__new__(Sequence)
        seq._allocate(sequence.shape[0])
        seq.masks = Masks.__new__(Masks)

        with nogil:
            # digitize the sequence, counting G and C on the fly; anything
            # that is not A/T/G/C (in either case) becomes an unknown `N`
            for i in range(seq.slen):
                letter = sequence[i]
                if letter == b'A' or letter == b'a':
                    seq.digits[i] = nucleotide.A
                elif letter == b'T' or letter == b't':
                    seq.digits[i] = nucleotide.T
                elif letter == b'G' or letter == b'g':
                    seq.digits[i] = nucleotide.G
                    gc_count += 1
                elif letter == b'C' or letter == b'c':
                    seq.digits[i] = nucleotide.C
                    gc_count += 1
                else:
                    seq.digits[i] = nucleotide.N

            # avoid a division by zero on empty sequences
            if seq.slen > 0:
                seq.gc = (<double> gc_count) / (<double> seq.slen)

            # record every run of unknown nucleotides as a mask whose `end`
            # is the index of the last unknown nucleotide of the run
            # NOTE(review): a run reaching the very end of the sequence is
            # never closed, so no mask is recorded for it; confirm that this
            # is intended before changing it.
            if mask:
                for i in range(seq.slen):
                    if seq.digits[i] == nucleotide.N:
                        if mask_begin == -1:
                            mask_begin = i
                    else:
                        if mask_begin != -1:
                            seq.masks._add_mask(mask_begin, i-1)
                            mask_begin = -1

        return seq

    @classmethod
    def from_string(cls, str sequence, bint mask = False):
        """from_string(cls, sequence, mask=False)\n--

        Create a new `Sequence` object from a Unicode sequence.

        Arguments:
            sequence (`str`): The Unicode sequence to use.
            mask (`bool`): Enable region-masking for spans of unknown
                characters, preventing genes from being built across them.

        Returns:
            `Sequence`: The digitized sequence, with its GC content
            computed and, if requested, its masked regions recorded.

        """
        cdef int      i
        cdef Py_UCS4  letter
        cdef Sequence seq
        cdef int      kind
        cdef void*    data
        cdef int      gc_count   = 0
        cdef int      mask_begin = -1

        # make sure the unicode string is in canonical form,
        # --> won't be needed anymore in Python 3.12
        IF SYS_VERSION_INFO_MAJOR <= 3 and SYS_VERSION_INFO_MINOR < 12:
            PyUnicode_READY(sequence)

        seq = Sequence.__new__(Sequence)
        seq._allocate(PyUnicode_GET_LENGTH(sequence))
        seq.masks = Masks.__new__(Masks)

        # raw access to the string storage, so that the loop below can run
        # without holding the GIL
        kind = PyUnicode_KIND(sequence)
        data = PyUnicode_DATA(sequence)

        with nogil:
            # digitize the sequence, counting G and C on the fly; anything
            # that is not A/T/G/C (in either case) becomes an unknown `N`
            for i in range(seq.slen):
                letter = PyUnicode_READ(kind, data, i)
                if letter == u'A' or letter == u'a':
                    seq.digits[i] = nucleotide.A
                elif letter == u'T' or letter == u't':
                    seq.digits[i] = nucleotide.T
                elif letter == u'G' or letter == u'g':
                    seq.digits[i] = nucleotide.G
                    gc_count += 1
                elif letter == u'C' or letter == u'c':
                    seq.digits[i] = nucleotide.C
                    gc_count += 1
                else:
                    seq.digits[i] = nucleotide.N

            # avoid a division by zero on empty sequences
            if seq.slen > 0:
                seq.gc = (<double> gc_count) / (<double> seq.slen)

            # record every run of unknown nucleotides as a mask whose `end`
            # is the index of the last unknown nucleotide of the run
            # NOTE(review): a run reaching the very end of the sequence is
            # never closed, so no mask is recorded for it; confirm that this
            # is intended before changing it.
            if mask:
                for i in range(seq.slen):
                    if seq.digits[i] == nucleotide.N:
                        if mask_begin == -1:
                            mask_begin = i
                    else:
                        if mask_begin != -1:
                            seq.masks._add_mask(mask_begin, i-1)
                            mask_begin = -1

        return seq

356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    # --- Magic methods ------------------------------------------------------

    def __cinit__(self):
        # start from an empty state: no buffer, zero length, no masks, so
        # that `__dealloc__` is safe even if `_allocate` is never called
        self.slen = 0
        self.gc = 0.0
        self.digits = NULL
        self.masks = None

    def __dealloc__(self):
        # release the digit buffer obtained with `PyMem_Malloc` in
        # `_allocate` (freeing a NULL pointer is a no-op)
        PyMem_Free(self.digits)

    def __len__(self):
        # number of digitized nucleotides, as recorded by `_allocate`
        return self.slen

    def __sizeof__(self):
        # size of the digit buffer (one `uint8_t` per nucleotide) plus
        # `sizeof(self)`
        return self.slen * sizeof(uint8_t) + sizeof(self)

373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
    def __str__(self):
        # Rebuild the textual sequence from the digits, one letter of the
        # `_letters` table per nucleotide. The output object is allocated
        # first, then filled in place without holding the GIL.
        cdef int     i
        cdef Py_UCS4 nuc

        # compile-time switch: one code path writes into a `bytes` buffer
        # that is decoded at the end, the other writes directly into a
        # `str` object
        IF SYS_VERSION_INFO_MAJOR <= 3 and SYS_VERSION_INFO_MINOR < 7 and SYS_IMPLEMENTATION_NAME == "pypy":
            cdef bytes dna
            cdef void* data
            # create an empty byte buffer that we can write to
            dna = PyBytes_FromStringAndSize(NULL, self.slen)
            data = <void*> PyBytes_AsString(dna)
        ELSE:
            cdef unicode dna
            cdef int     kind
            cdef void*   data
            # create an empty string that we can write to
            # (0x7F is the largest code point the string will contain,
            # i.e. the string only has to hold ASCII characters)
            dna  = PyUnicode_New(self.slen, 0x7F)
            kind = PyUnicode_KIND(dna)
            data = PyUnicode_DATA(dna)

        with nogil:
            for i in range(self.slen):
                # map the digit back to its letter
                nuc = _letters[self.digits[i]]
                IF SYS_VERSION_INFO_MAJOR <= 3 and SYS_VERSION_INFO_MINOR < 7 and SYS_IMPLEMENTATION_NAME == "pypy":
                    (<char*> data)[i] = nuc
                ELSE:
                    PyUnicode_WRITE(kind, data, i, nuc)

        IF SYS_VERSION_INFO_MAJOR <= 3 and SYS_VERSION_INFO_MINOR < 7 and SYS_IMPLEMENTATION_NAME == "pypy":
            return dna.decode("ascii")
        ELSE:
            return dna

405
406
407
408
409
410
411
412
413
414
415
    # --- C interface -------------------------------------------------------

    cdef int _allocate(self, int slen) except 1:
        """Allocate a zero-filled digit buffer for ``slen`` nucleotides.

        The object is only updated once the allocation has succeeded, so a
        failure leaves ``slen`` and ``digits`` consistent with each other.
        Any buffer held previously is released.

        Raises:
            `MemoryError`: When the buffer cannot be allocated.

        """
        cdef uint8_t* digits = <uint8_t*> PyMem_Malloc(slen * sizeof(uint8_t))
        if digits == NULL:
            raise MemoryError()
        with nogil:
            memset(digits, 0, slen * sizeof(uint8_t))
        # drop the previous buffer, if any (freeing NULL is a no-op)
        PyMem_Free(self.digits)
        self.digits = digits
        self.slen = slen
        return 0

416
417
418
419
420
421
422
423
    cdef char _amino(
        self,
        int i,
        int tt,
        int strand = 1,
        bint is_init = False,
        char unknown_residue = b"X"
    ) nogil:
        """Translate the codon found at position ``i`` into an amino acid.

        Arguments:
            i (`int`): The position of the first nucleotide of the codon.
                With ``strand == 1`` the codon is read from ``digits[i]``
                onwards; otherwise it is read backwards from
                ``digits[slen - 1 - i]`` and complemented.
            tt (`int`): The translation table to use.
            strand (`int`): The strand to read the codon from.
            is_init (`bool`): Whether the codon is the first of a gene,
                in which case any start codon is translated to ``M``.
            unknown_residue (`char`): The letter returned when the codon
                cannot be translated by any of the rules below.

        Returns:
            `char`: The one-letter code of the amino acid, ``*`` for a
            stop codon, or ``unknown_residue``.

        """
        cdef uint8_t x0
        cdef uint8_t x1
        cdef uint8_t x2

        # stop and (initial) start codons depend on the translation table
        # and are resolved by the dedicated helpers
        if _is_stop(self.digits, self.slen, i, tt, strand):
            return b"*"
        if _is_start(self.digits, self.slen, i, tt, strand) and is_init:
            return b"M"

        # extract the three nucleotides of the codon
        if strand == 1:
            x0 = self.digits[i]
            x1 = self.digits[i+1]
            x2 = self.digits[i+2]
        else:
            x0 = _complement[self.digits[self.slen - 1 - i]]
            x1 = _complement[self.digits[self.slen - 2 - i]]
            x2 = _complement[self.digits[self.slen - 3 - i]]

        # codon table, with the table-specific reassignments inlined; a
        # codon matching no rule falls through to `unknown_residue`
        if x0 == nucleotide.T and x1 == nucleotide.T and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"F"
        if x0 == nucleotide.T and x1 == nucleotide.T and (x2 == nucleotide.A or x2 == nucleotide.G):
            return b"L"
        if x0 == nucleotide.T and x1 == nucleotide.C and x2 != nucleotide.N:
            return b"S"
        if x0 == nucleotide.T and x1 == nucleotide.A and x2 == nucleotide.T:
            return b"Y"
        if x0 == nucleotide.T and x1 == nucleotide.A and x2 == nucleotide.C:
            return b"Y"
        if x0 == nucleotide.T and x1 == nucleotide.A and x2 == nucleotide.A:
            if tt == 6:
                return b"Q"
            elif tt == 14:
                return b"Y"
        if x0 == nucleotide.T and x1 == nucleotide.A and x2 == nucleotide.G:
            if tt == 6 or tt == 15:
                return b"Q"
            elif tt == 22:
                return b"L"
        if x0 == nucleotide.T and x1 == nucleotide.G and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"C"
        if x0 == nucleotide.T and x1 == nucleotide.G and x2 == nucleotide.A:
            return b"G" if tt == 25 else b"W"
        if x0 == nucleotide.T and x1 == nucleotide.G and x2 == nucleotide.G:
            return b"W"
        if x0 == nucleotide.C and x1 == nucleotide.T and (x2 == nucleotide.T or x2 == nucleotide.C or x2 == nucleotide.A):
            return b"T" if tt == 3 else b"L"
        if x0 == nucleotide.C and x1 == nucleotide.T and x2 == nucleotide.G:
            return b"T" if tt == 3 else b"S" if tt == 12 else b"L"
        if x0 == nucleotide.C and x1 == nucleotide.C and x2 != nucleotide.N:
            return b"P"
        if x0 == nucleotide.C and x1 == nucleotide.A and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"H"
        if x0 == nucleotide.C and x1 == nucleotide.A and (x2 == nucleotide.A or x2 == nucleotide.G):
            return b"Q"
        if x0 == nucleotide.C and x1 == nucleotide.G and x2 != nucleotide.N:
            return b"R"
        if x0 == nucleotide.A and x1 == nucleotide.T and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"I"
        if x0 == nucleotide.A and x1 == nucleotide.T and x2 == nucleotide.A:
            return b"M" if tt == 2 or tt == 3 or tt == 5 or tt == 13 or tt == 22 else b"I"
        if x0 == nucleotide.A and x1 == nucleotide.T and x2 == nucleotide.G:
            return b"M"
        if x0 == nucleotide.A and x1 == nucleotide.C and x2 != nucleotide.N:
            return b"T"
        if x0 == nucleotide.A and x1 == nucleotide.A and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"N"
        if x0 == nucleotide.A and x1 == nucleotide.A and x2 == nucleotide.A:
            return b"N" if tt == 9 or tt == 14 or tt == 21 else b"K"
        if x0 == nucleotide.A and x1 == nucleotide.A and x2 == nucleotide.G:
            return b"K"
        if x0 == nucleotide.A and x1 == nucleotide.G and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"S"
        if x0 == nucleotide.A and x1 == nucleotide.G and (x2 == nucleotide.A or x2 == nucleotide.G):
            return b"G" if tt == 13 else b"S" if tt == 5 or tt == 9 or tt == 14 or tt == 21 else b"R"
        if x0 == nucleotide.G and x1 == nucleotide.T and x2 != nucleotide.N:
            return b"V"
        if x0 == nucleotide.G and x1 == nucleotide.C and x2 != nucleotide.N:
            return b"A"
        if x0 == nucleotide.G and x1 == nucleotide.A and (x2 == nucleotide.T or x2 == nucleotide.C):
            return b"D"
        if x0 == nucleotide.G and x1 == nucleotide.A and (x2 == nucleotide.A or x2 == nucleotide.G):
            return b"E"
        if x0 == nucleotide.G and x1 == nucleotide.G and x2 != nucleotide.N:
            return b"G"

        return unknown_residue
510

511
512
513
514
515
516
517
    cdef int _shine_dalgarno_exact(
        self,
        const int pos,
        const int start,
        const _training* tinf,
        const int strand
    ) nogil:
        """Find the best exact Shine-Dalgarno motif in a 6-base window.

        The window starting at ``pos`` is compared to *AGGAGG*, and every
        sub-motif of 3 to 6 bases without mismatch is classified by its
        score and its distance to ``start``.

        Returns:
            `int`: The index (in ``tinf.rbs_wt``) of the motif class with
            the highest weight, or 0 if no motif was found.

        """
        cdef int i
        cdef int j
        cdef int k
        cdef int mism
        cdef int rdis
        cdef int limit
        cdef int max_val
        cdef int cur_val = 0
        cdef int match[6]
        cdef int cur_ctr
        cdef int dis_flag

        # reset the match array
        match[0] = match[1] = match[2] = match[3] = match[4] = match[5] = -10

        # Compare the 6-base region to AGGAGG
        # (the window never extends closer than 4 bases to `start`)
        limit = min(6, start - 4 - pos)
        for i in range(limit):
            if i%3 == 0 and _is_a(self.digits, self.slen, pos+i, strand):
                match[i] = 2
            elif i%3 != 0 and _is_g(self.digits, self.slen, pos+i, strand):
                match[i] = 3
            else:
                match[i] = -10

        # Find the maximally scoring motif
        # (`i` is the motif length, `j` its offset inside the window)
        max_val = 0
        for i in range(limit, 2, -1):
            for j in range(limit+1-i):
                cur_ctr = -2
                mism = 0
                for k in range(j, j+i):
                    cur_ctr += match[k]
                    if match[k] < 0:
                        mism += 1
                if mism > 0:
                    continue
                # distance between the end of the motif and `start`
                rdis = start - (pos + j + i)
                if rdis < 5 and i < 5:
                    dis_flag = 2
                elif rdis < 5 and i >= 5:
                    dis_flag = 1
                elif rdis > 10 and rdis <= 12 and i < 5:
                    dis_flag = 1
                elif rdis > 10 and rdis <= 12 and i >= 5:
                    dis_flag = 2
                elif rdis >= 13:
                    dis_flag = 3
                else:
                    dis_flag = 0
                if rdis > 15 or cur_ctr < 6:
                    continue

                # Exact-Matching RBS Motifs
                # (`cur_ctr >= 6` is guaranteed by the `continue` above)
                if cur_ctr == 6 and dis_flag == 2:
                    cur_val = 1
                elif cur_ctr == 6 and dis_flag == 3:
                    cur_val = 2
                elif cur_ctr == 8 and dis_flag == 3:
                    cur_val = 3
                elif cur_ctr == 9 and dis_flag == 3:
                    cur_val = 3
                elif cur_ctr == 6 and dis_flag == 1:
                    cur_val = 6
                elif cur_ctr == 11 and dis_flag == 3:
                    cur_val = 10
                elif cur_ctr == 12 and dis_flag == 3:
                    cur_val = 10
                elif cur_ctr == 14 and dis_flag == 3:
                    cur_val = 10
                elif cur_ctr == 8 and dis_flag == 2:
                    cur_val = 11
                elif cur_ctr == 9 and dis_flag == 2:
                    cur_val = 11
                elif cur_ctr == 8 and dis_flag == 1:
                    cur_val = 12
                elif cur_ctr == 9 and dis_flag == 1:
                    cur_val = 12
                elif cur_ctr == 6 and dis_flag == 0:
                    cur_val = 13
                elif cur_ctr == 8 and dis_flag == 0:
                    cur_val = 15
                elif cur_ctr == 9 and dis_flag == 0:
                    cur_val = 16
                elif cur_ctr == 11 and dis_flag == 2:
                    cur_val = 20
                elif cur_ctr == 11 and dis_flag == 1:
                    cur_val = 21
                elif cur_ctr == 11 and dis_flag == 0:
                    cur_val = 22
                elif cur_ctr == 12 and dis_flag == 2:
                    cur_val = 20
                elif cur_ctr == 12 and dis_flag == 1:
                    cur_val = 23
                elif cur_ctr == 12 and dis_flag == 0:
                    cur_val = 24
                elif cur_ctr == 14 and dis_flag == 2:
                    cur_val = 25
                elif cur_ctr == 14 and dis_flag == 1:
                    cur_val = 26
                elif cur_ctr == 14 and dis_flag == 0:
                    cur_val = 27

                # keep the class with the highest weight, breaking ties in
                # favour of the highest class index
                if tinf.rbs_wt[cur_val] < tinf.rbs_wt[max_val]:
                    continue
                if tinf.rbs_wt[cur_val] == tinf.rbs_wt[max_val] and cur_val < max_val:
                    continue
                max_val = cur_val

        return max_val

632
633
634
635
636
637
638
    cdef int _shine_dalgarno_mm(
        self,
        const int pos,
        const int start,
        const _training* tinf,
        const int strand
    ) nogil:
        """Find the best Shine-Dalgarno motif with exactly one mismatch.

        The window starting at ``pos`` is compared to *AGGAGG*, and every
        sub-motif of 5 or 6 bases with a single mismatch is classified by
        its score and its distance to ``start``.

        Returns:
            `int`: The index (in ``tinf.rbs_wt``) of the motif class with
            the highest weight, or 0 if no motif was found.

        """
        cdef int i
        cdef int j
        cdef int k
        cdef int mism
        cdef int rdis
        cdef int limit
        cdef int max_val
        cdef int cur_val = 0
        cdef int match[6]
        cdef int cur_ctr
        cdef int dis_flag

        # reset the match array
        match[0] = match[1] = match[2] = match[3] = match[4] = match[5] = -10

        # Compare the 6-base region to AGGAGG
        # (the window never extends closer than 4 bases to `start`)
        limit = min(6, start - 4 - pos)
        for i in range(limit):
            if i%3 == 0:
                match[i] = -3 + 5*_is_a(self.digits, self.slen, pos+i, strand)
            else:
                match[i] = -2 + 5*_is_g(self.digits, self.slen, pos+i, strand)

        # Find the maximally scoring motif
        # (`i` is the motif length, `j` its offset inside the window)
        max_val = 0
        for i in range(limit, 4, -1):
            for j in range(limit+1-i):
                cur_ctr = -2
                mism = 0
                for k in range(j, j+i):
                    cur_ctr += match[k]
                    if match[k] < 0:
                        mism += 1
                        # heavily penalize a mismatch located within the
                        # first two or last two bases of the motif
                        if k <= j+1 or k >= j+i-2:
                            cur_ctr -= 10
                if mism != 1:
                    continue
                # distance between the end of the motif and `start`
                rdis = start - (pos + j + i)
                if rdis < 5:
                    dis_flag = 1
                elif rdis > 10 and rdis <= 12:
                    dis_flag = 2
                elif rdis >= 13:
                    dis_flag = 3
                else:
                    dis_flag = 0
                if rdis > 15 or cur_ctr < 6:
                    continue

                # Single-Matching RBS Motifs
                # (`cur_ctr >= 6` is guaranteed by the `continue` above)
                if cur_ctr == 6 and dis_flag == 3:
                    cur_val = 2
                elif cur_ctr == 7 and dis_flag == 3:
                    cur_val = 2
                elif cur_ctr == 9 and dis_flag == 3:
                    cur_val = 3
                elif cur_ctr == 6 and dis_flag == 2:
                    cur_val = 4
                elif cur_ctr == 6 and dis_flag == 1:
                    cur_val = 5
                elif cur_ctr == 6 and dis_flag == 0:
                    cur_val = 9
                elif cur_ctr == 7 and dis_flag == 2:
                    cur_val = 7
                elif cur_ctr == 7 and dis_flag == 1:
                    cur_val = 8
                elif cur_ctr == 7 and dis_flag == 0:
                    cur_val = 14
                elif cur_ctr == 9 and dis_flag == 2:
                    cur_val = 17
                elif cur_ctr == 9 and dis_flag == 1:
                    cur_val = 18
                elif cur_ctr == 9 and dis_flag == 0:
                    cur_val = 19

                # keep the class with the highest weight, breaking ties in
                # favour of the highest class index
                if tinf.rbs_wt[cur_val] < tinf.rbs_wt[max_val]:
                    continue
                if tinf.rbs_wt[cur_val] == tinf.rbs_wt[max_val] and cur_val < max_val:
                    continue
                max_val = cur_val

        return max_val

725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
    # --- Python interface ---------------------------------------------------

    cpdef int shine_dalgarno(
        self,
        int pos,
        int start,
        TrainingInfo training_info,
        int strand=1,
        bint exact=True
    ) except -1:
        """shine_dalgarno(self, pos, start, training_info, strand=1, exact=True)\n--

        Find the highest scoring Shine-Dalgarno motif upstream of ``start``.

        Arguments:
            pos (`int`): The position where to look for the Shine-Dalgarno
                motif. Must be upstream of ``start`` (before or after,
                depending on the strand).
            start (`int`): The position of the start codon being considered.
            training_info (`~pyrodigal.TrainingInfo`): The training info
                containing the weights for the different ribosome weights.

        Keyword Arguments:
            strand (`int`): The strand to scan.
            exact (`bool`): `True` to score Shine-Dalgarno motifs matching
                exactly *AGGAGG*, `False` to allow one base mismatch.

        Returns:
            `int`: The index of the highest scoring Shine-Dalgarno motif.

        Raises:
            `TypeError`: When ``training_info`` is `None`.
            `ValueError`: On invalid `strand`, `pos` or `start` values.

        """
        cdef int phase

        # `training_info.tinf` is dereferenced without the GIL below, which
        # would crash the interpreter on `None`
        if training_info is None:
            raise TypeError("`training_info` must be a TrainingInfo, not None")

        if strand != 1 and strand != -1:
            raise ValueError(f"Invalid strand: {strand!r} (must be +1 or -1)")
        if pos < 0:
            raise ValueError(f"`pos` must be positive")
        if start < 0:
            raise ValueError(f"`start` must be positive")
        # the scan reads the digits up to `start - 5`: `start` must be inside
        # the sequence to keep every read inside the buffer
        if start >= self.slen:
            raise ValueError("`start` must be lower than the sequence length")

        # NOTE(review): for `strand == -1` the check below requires
        # `pos >= start + 6` (although the message says "at most"), which
        # makes the window scanned by the C helpers empty; confirm the
        # coordinate convention expected for the reverse strand.
        if strand == 1 and pos > start - 5:
            raise ValueError(f"`pos` is too close to `start` (must be at most `start` - 5)")
        elif strand == -1 and pos < start + 6:
            raise ValueError(f"`pos` is too close to `start` (must be at most `start + 6`)")

        with nogil:
            if exact:
                phase = self._shine_dalgarno_exact(pos, start, training_info.tinf, strand)
            else:
                phase = self._shine_dalgarno_mm(pos, start, training_info.tinf, strand)

        return phase

780

781
782
# --- Connection Scorer ------------------------------------------------------

783
784
785
786
787
788
789
790
# Python-visible flags describing the SIMD support of the extension: the
# `*_BUILD_SUPPORT` flags mirror the compile-time constants, while the
# `*_RUNTIME_SUPPORT` flags describe the machine running the code. All
# default to `False` and are only updated for the CPU family the extension
# was compiled for.
_TARGET_CPU           = TARGET_CPU
_AVX2_RUNTIME_SUPPORT = False
_NEON_RUNTIME_SUPPORT = False
_SSE2_RUNTIME_SUPPORT = False
_AVX2_BUILD_SUPPORT   = False
_NEON_BUILD_SUPPORT   = False
_SSE2_BUILD_SUPPORT   = False

IF TARGET_CPU == "x86":
    cdef X86Info cpu_info = GetX86Info()
    _SSE2_RUNTIME_SUPPORT = cpu_info.features.sse2 != 0
    _AVX2_RUNTIME_SUPPORT = cpu_info.features.avx2 != 0
    _SSE2_BUILD_SUPPORT   = SSE2_BUILD_SUPPORT
    _AVX2_BUILD_SUPPORT   = AVX2_BUILD_SUPPORT
ELIF TARGET_CPU == "arm":
    cdef ArmInfo cpu_info = GetArmInfo()
    _NEON_RUNTIME_SUPPORT = cpu_info.features.neon != 0
    _NEON_BUILD_SUPPORT   = NEON_BUILD_SUPPORT
ELIF TARGET_CPU == "aarch64":
    # no CPU probing on this target: NEON is reported as available
    _NEON_RUNTIME_SUPPORT = True
    _NEON_BUILD_SUPPORT   = NEON_BUILD_SUPPORT
804
805
806
807
808
809

# Identifiers for the connection pre-filter implementations that a
# `ConnectionScorer` can use: `NONE` disables the pre-filter entirely,
# `SSE2`, `AVX2` and `NEON` select the `skippable_sse`, `skippable_avx`
# and `skippable_neon` routines, and `GENERIC` selects `skippable_generic`.
cdef enum simd_backend:
    NONE = 0
    SSE2 = 1
    AVX2 = 2
    NEON = 3
    GENERIC = 4
811

812
813
cdef class ConnectionScorer:

814
815
    # --- Magic methods ------------------------------------------------------

816
817
818
819
820
821
822
    def __cinit__(self):
        # No buffer is allocated yet: the owning (`*_raw`) pointers and the
        # aligned pointers derived from them all start as NULL, so that
        # `__dealloc__` is safe even if `_index` was never called.
        self.capacity = 0

        self.skip_connection_raw = NULL
        self.skip_connection     = NULL

        self.node_types_raw = NULL
        self.node_types     = NULL

        self.node_strands_raw = NULL
        self.node_strands     = NULL

        self.node_frames_raw = NULL
        self.node_frames     = NULL

823
    def __init__(self, str backend="detect"):
        """Create a new connection scorer.

        Arguments:
            backend (`str` or `None`): The implementation to use for the
                connection pre-filter. Pass ``"detect"`` (the default) to
                use the best SIMD implementation that was both compiled
                into the extension and is supported by the local CPU,
                falling back to no pre-filter when there is none. Pass
                ``"generic"`` to use the generic implementation, `None`
                to disable the pre-filter, or an architecture-specific
                name (``"sse"`` or ``"avx"`` on x86, ``"neon"`` on Arm)
                to force a given SIMD implementation.

        Raises:
            `ValueError`: When ``backend`` is not a name supported on the
                architecture the extension was compiled for.
            `RuntimeError`: When the requested SIMD implementation was
                not compiled into the extension, or cannot run on the
                local CPU.

        """
        IF TARGET_CPU == "x86":
            if backend == "detect":
                # start without a pre-filter, then upgrade to the widest
                # instruction set available (AVX2 takes over SSE2)
                self.backend = simd_backend.NONE
                IF SSE2_BUILD_SUPPORT:
                    if _SSE2_RUNTIME_SUPPORT:
                        self.backend = simd_backend.SSE2
                IF AVX2_BUILD_SUPPORT:
                    if _AVX2_RUNTIME_SUPPORT:
                        self.backend = simd_backend.AVX2
            elif backend == "sse":
                IF not SSE2_BUILD_SUPPORT:
                    raise RuntimeError("Extension was compiled without SSE2 support")
                ELSE:
                    if not _SSE2_RUNTIME_SUPPORT:
                        raise RuntimeError("Cannot run SSE2 instructions on this machine")
                    self.backend = simd_backend.SSE2
            elif backend == "avx":
                IF not AVX2_BUILD_SUPPORT:
                    raise RuntimeError("Extension was compiled without AVX2 support")
                ELSE:
                    if not _AVX2_RUNTIME_SUPPORT:
                        raise RuntimeError("Cannot run AVX2 instructions on this machine")
                    self.backend = simd_backend.AVX2
            elif backend == "generic":
                self.backend = simd_backend.GENERIC
            elif backend is None:
                self.backend = simd_backend.NONE
            else:
                raise ValueError(f"Unsupported backend on this architecture: {backend}")
        ELIF TARGET_CPU == "arm" or TARGET_CPU == "aarch64":
            if backend == "detect":
                # start without a pre-filter, then upgrade to NEON if possible
                self.backend = simd_backend.NONE
                IF NEON_BUILD_SUPPORT:
                    if _NEON_RUNTIME_SUPPORT:
                        self.backend = simd_backend.NEON
            elif backend == "neon":
                IF not NEON_BUILD_SUPPORT:
                    raise RuntimeError("Extension was compiled without NEON support")
                ELSE:
                    if not _NEON_RUNTIME_SUPPORT:
                        raise RuntimeError("Cannot run NEON instructions on this machine")
                    self.backend = simd_backend.NEON
            elif backend == "generic":
                self.backend = simd_backend.GENERIC
            elif backend is None:
                self.backend = simd_backend.NONE
            else:
                raise ValueError(f"Unsupported backend on this architecture: {backend}")
        ELSE:
            # no SIMD implementation exists for this target, so detection
            # always ends up without a pre-filter; this must be part of the
            # same `if`/`elif` chain as the other names, otherwise the
            # default "detect" value falls through to the `ValueError`
            if backend == "detect":
                self.backend = simd_backend.NONE
            elif backend == "generic":
                self.backend = simd_backend.GENERIC
            elif backend is None:
                self.backend = simd_backend.NONE
            else:
                raise ValueError(f"Unsupported backend on this architecture: {backend}")
881

882
883
884
885
886
887
    def __dealloc__(self):
        # Only the `*_raw` pointers are passed to the allocator; the aligned
        # pointers are derived from them and point inside the same buffers,
        # so they must not be freed on their own.
        PyMem_Free(self.skip_connection_raw)
        PyMem_Free(self.node_frames_raw)
        PyMem_Free(self.node_strands_raw)
        PyMem_Free(self.node_types_raw)

888
889
890
    # --- C interface --------------------------------------------------------

    cdef int _index(self, Nodes nodes) nogil except -1:
        # Copy the type, strand and frame (`ndx % 3`) of every node in
        # `nodes` into the aligned buffers of the scorer and reset the
        # `skip_connection` flags, growing the buffers beforehand when
        # `nodes` holds more nodes than the current capacity.
        #
        # Returns 0 on success, or -1 with a `MemoryError` set if a buffer
        # could not be grown.
        cdef size_t i
        cdef void*  tmp

        # nothing to be done if we are using the Prodigal code
        if self.backend == simd_backend.NONE:
            return 0

        # reallocate if needed
        if self.capacity < nodes.length:
            with gil:
                # Grow the buffers one at a time, going through a temporary
                # pointer: on failure `PyMem_Realloc` leaves the old block
                # untouched, so overwriting the owning pointer with NULL
                # would leak it. The aligned pointer is refreshed as soon
                # as its buffer has been grown, because the block may have
                # moved; this way a failure on a later buffer never leaves
                # a dangling aligned pointer behind. Each buffer gets 0x1F
                # extra bytes so that a 32-byte aligned address followed by
                # `nodes.length` elements always fits inside of it.
                tmp = PyMem_Realloc(self.skip_connection_raw, nodes.length * sizeof(uint8_t) + 0x1F)
                if tmp == NULL:
                    raise MemoryError("Failed to allocate memory for scoring bypass index")
                self.skip_connection_raw = <uint8_t*> tmp
                self.skip_connection     = <uint8_t*> ((<uintptr_t> tmp + 0x1F) & (~0x1F))

                tmp = PyMem_Realloc(self.node_types_raw, nodes.length * sizeof(uint8_t) + 0x1F)
                if tmp == NULL:
                    raise MemoryError("Failed to allocate memory for node type array")
                self.node_types_raw = <uint8_t*> tmp
                self.node_types     = <uint8_t*> ((<uintptr_t> tmp + 0x1F) & (~0x1F))

                tmp = PyMem_Realloc(self.node_strands_raw, nodes.length * sizeof(int8_t) + 0x1F)
                if tmp == NULL:
                    raise MemoryError("Failed to allocate memory for node strand array")
                self.node_strands_raw = <int8_t*> tmp
                self.node_strands     = <int8_t*> ((<uintptr_t> tmp + 0x1F) & (~0x1F))

                tmp = PyMem_Realloc(self.node_frames_raw, nodes.length * sizeof(uint8_t) + 0x1F)
                if tmp == NULL:
                    raise MemoryError("Failed to allocate memory for node frame array")
                self.node_frames_raw = <uint8_t*> tmp
                self.node_frames     = <uint8_t*> ((<uintptr_t> tmp + 0x1F) & (~0x1F))

            # record new capacity, only once every buffer has been grown
            self.capacity = nodes.length

        # copy data from the array of nodes
        for i in range(nodes.length):
            self.node_types[i]      = nodes.nodes[i].type
            self.node_strands[i]    = nodes.nodes[i].strand
            self.node_frames[i]     = nodes.nodes[i].ndx % 3
            self.skip_connection[i] = False

        # return 0 if no exceptions were raised
        return 0

928
929
930
931
932
    cdef int _compute_skippable(
        self,
        int min,
        int i
    ) nogil:
        # Run the pre-filter selected by `self.backend` over the indexed
        # node data, passing it `min`, `i` and the `skip_connection` buffer.
        # Presumably the routine flags in `skip_connection` the nodes from
        # `min` up to `i` that cannot be connected to node `i` -- confirm
        # against the `skippable_*` implementations.
        #
        # Only the routines compiled into the extension are dispatched to.
        # Nothing is done for any other backend, including `NONE`. Always
        # returns 0.
        IF AVX2_BUILD_SUPPORT:
            if self.backend == simd_backend.AVX2:
                skippable_avx(self.node_strands, self.node_types, self.node_frames, min, i, self.skip_connection)
                return 0
        IF SSE2_BUILD_SUPPORT:
            if self.backend == simd_backend.SSE2:
                skippable_sse(self.node_strands, self.node_types, self.node_frames, min, i, self.skip_connection)
                return 0
        IF NEON_BUILD_SUPPORT:
            if self.backend == simd_backend.NEON:
                skippable_neon(self.node_strands, self.node_types, self.node_frames, min, i, self.skip_connection)
                return 0
        if self.backend == simd_backend.GENERIC:
            skippable_generic(self.node_strands, self.node_types, self.node_frames, min, i, self.skip_connection)
            return 0
        return 0

950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
    @staticmethod
    cdef void _score_connection(
        Nodes nodes,
        const int p1,
        const int p2,
        const _training* tinf,
        const bint final
    ) nogil:
        cdef _node* n1 = &nodes.nodes[p1]
        cdef _node* n2 = &nodes.nodes[p2]
        cdef _node* n3

        cdef int i
        cdef int bnd
        cdef int ovlp  = 0
        cdef int maxfr = -1
        cdef int left  = n1.ndx
        cdef int right = n2.ndx

        cdef double maxval
        cdef double score   = 0.0
        cdef double scr_mod = 0.0

        # NOTE(@althonos): We can skip checking for invalid connections here
        #                  because the connection scorer already checked for
        #                  them using the SIMD filter

        # --- Edge Artifacts ---
        if n1.traceb == -1 and n1.strand == 1 and n1.type == node_type.STOP:
            return
        elif n1.traceb == -1 and n1.strand == -1 and n1.type != node_type.STOP:
            return

        # --- Genes ---
        # 5'fwd->3'fwd
        elif n1.strand == 1 and n2.strand == 1 and n1.type != node_type.STOP and n2.type == node_type.STOP:
            if n2.stop_val >= n1.ndx:
                return
            right += 2
            if final:
                score = n1.cscore + n1.sscore
            else:
                scr_mod = tinf.bias[0]*n1.gc_score[0] + tinf.bias[1]*n1.gc_score[1] + tinf.bias[2]*n1.gc_score[2]
        # 3'rev->5'rev */
        elif n1.strand == -1 and n2.strand == -1 and n1.type == node_type.STOP and n2.type != node_type.STOP:
            if n1.stop_val <= n2.ndx:
                return
            left -= 2;
            if final == 0:
                scr_mod = tinf.bias[0]*n2.gc_score[0] + tinf.bias[1]*n2.gc_score[1] + tinf.bias[2]*n2.gc_score[2]
            elif final:
For faster browsing, not all history is shown. View entire blame