Commit 6013273e authored by Martin Larralde's avatar Martin Larralde
Browse files

Implement pickle protocol for `pyrodigal.Nodes`

parent cab5bb41
......@@ -176,6 +176,9 @@ cdef class Nodes:
cdef size_t capacity
cdef size_t length
cpdef list __getstate__(self)
cpdef object __setstate__(self, list state)
cdef inline _node* _add_node(
const int ndx,
import threading
import typing
from typing import Iterable, Iterator, Optional, Set, TextIO, Tuple, Union
from typing import Iterable, Iterator, List, Dict, Optional, Set, TextIO, Tuple, Union
# --- Globals ----------------------------------------------------------------
......@@ -98,6 +98,8 @@ class Nodes(typing.Sequence[Node]):
def __getitem__(self, index: int) -> Node: ...
def __iter__(self) -> Iterator[Node]: ...
def __reversed__(self) -> Iterator[Node]: ...
def __getstate__(self) -> List[Dict[str, object]]: ...
def __setstate__(self, state: List[Dict[str, object]]) -> None: ...
def copy(self) -> Nodes: ...
def clear(self) -> None: ...
def extract(
......@@ -50,6 +50,7 @@ References:
from cpython.buffer cimport PyBUF_READ, PyBUF_WRITE
from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AsString
from cpython.exc cimport PyErr_CheckSignals
from cpython.list cimport PyList_New, PyList_SET_ITEM
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
from cpython.memoryview cimport PyMemoryView_FromMemory
from cpython.ref cimport Py_INCREF
......@@ -1459,6 +1460,98 @@ cdef class Nodes:
def __sizeof__(self):
return self.capacity * sizeof(_node) + sizeof(self)
cpdef list __getstate__(self):
cdef int i
return [
"type": self.nodes[i].type,
"edge": self.nodes[i].edge,
"ndx": self.nodes[i].ndx,
"strand": self.nodes[i].strand,
"stop_val": self.nodes[i].stop_val,
"star_ptr": [
"gc_bias": self.nodes[i].gc_bias,
"gc_score": [
"cscore": self.nodes[i].cscore,
"gc_cont": self.nodes[i].gc_cont,
"rbs": [
"motif": {
"ndx": self.nodes[i].mot.ndx,
"len": self.nodes[i].mot.len,
"spacer": self.nodes[i].mot.spacer,
"spacendx": self.nodes[i].mot.spacendx,
"score": self.nodes[i].mot.spacendx,
"uscore": self.nodes[i].uscore,
"tscore": self.nodes[i].tscore,
"rscore": self.nodes[i].rscore,
"sscore": self.nodes[i].sscore,
"traceb": self.nodes[i].traceb,
"tracef": self.nodes[i].tracef,
"ov_mark": self.nodes[i].ov_mark,
"score": self.nodes[i].score,
"elim": self.nodes[i].elim,
for i in range(self.length)
cpdef object __setstate__(self, list state):
"""__setstate__(self, state)\n--
cdef int i
cdef dict node
cdef dict motif
cdef size_t old_capacity = self.capacity
# realloc to the exact number of nodes
self.length = len(state)
self.capacity = MIN_NODES_ALLOC if self.length == 0 else self.length
self.nodes = <_node*> PyMem_Realloc(self.nodes, self.capacity * sizeof(_node))
if self.nodes == NULL:
raise MemoryError("Failed to reallocate node array")
# copy node data from the state dictionary
for i, node in enumerate(state):
motif = node["motif"]
self.nodes[i].type = node["type"]
self.nodes[i].edge = node["edge"]
self.nodes[i].ndx = node["ndx"]
self.nodes[i].strand = node["strand"]
self.nodes[i].stop_val = node["stop_val"]
self.nodes[i].star_ptr = node["star_ptr"]
self.nodes[i].gc_bias = node["gc_bias"]
self.nodes[i].gc_score = node["gc_score"]
self.nodes[i].cscore = node["cscore"]
self.nodes[i].gc_cont = node["gc_cont"]
self.nodes[i].rbs = node["rbs"]
self.nodes[i].mot.ndx = motif["ndx"]
self.nodes[i].mot.len = motif["len"]
self.nodes[i].mot.spacer = motif["spacer"]
self.nodes[i].mot.spacendx = motif["spacendx"]
self.nodes[i].mot.score = motif["score"]
self.nodes[i].uscore = node["uscore"]
self.nodes[i].tscore = node["tscore"]
self.nodes[i].rscore = node["rscore"]
self.nodes[i].sscore = node["sscore"]
self.nodes[i].traceb = node["traceb"]
self.nodes[i].tracef = node["tracef"]
self.nodes[i].ov_mark = node["ov_mark"]
self.nodes[i].score = node["score"]
self.nodes[i].elim = node["elim"]
# --- C interface --------------------------------------------------------
cdef inline _node* _add_node(
......@@ -2,6 +2,7 @@ import
import gzip
import os
import sys
import pickle
import unittest
from .. import Nodes, Sequence
......@@ -12,6 +13,19 @@ from .fasta import parse
class TestNodes(unittest.TestCase):
def assertNodeEqual(self, n1, n2):
self.assertEqual(n1.index, n2.index, "indices differ")
self.assertEqual(n1.strand, n2.strand, "strands differ")
self.assertEqual(n1.type, n2.type, "types differ")
self.assertEqual(n1.edge, n2.edge, "edge differ")
self.assertEqual(n1.gc_bias, n2.gc_bias, "GC biases differ")
self.assertEqual(n1.cscore, n2.cscore, "cscores differ")
self.assertEqual(n1.gc_cont, n2.gc_cont, "GC contents differ")
self.assertEqual(n1.score, n2.score, "GC contents differ")
self.assertEqual(n1.rscore, n2.rscore, "rscores differ")
self.assertEqual(n1.sscore, n2.sscore, "sscores differ")
self.assertEqual(n1.tscore, n2.tscore, "tscores differ")
def setUpClass(cls):
data = os.path.realpath(os.path.join(__file__, "..", "data"))
......@@ -39,10 +53,26 @@ class TestNodes(unittest.TestCase):
nodes1.extract(seq, translation_table=tt)
nodes2 = nodes1.copy()
for n1, n2 in zip(nodes1, nodes2):
self.assertEqual(n1.type, n2.type)
self.assertNodeEqual(n1, n2)
def test_copy_empty(self):
nodes = Nodes()
copy = nodes.copy()
self.assertEqual(len(nodes), 0)
self.assertEqual(len(copy), 0)
def test_pickle(self):
tt = METAGENOMIC_BINS[0].training_info.translation_table
seq = Sequence.from_string(self.record.seq)
nodes1 = Nodes()
nodes1.extract(seq, translation_table=tt)
nodes2 = pickle.loads(pickle.dumps(nodes1))
self.assertEqual(len(nodes1), len(nodes2), "lengths differ")
for n1, n2 in zip(nodes1, nodes2):
self.assertNodeEqual(n1, n2)
def test_pickle_empty(self):
nodes1 = Nodes()
nodes2 = pickle.loads(pickle.dumps(nodes1))
self.assertEqual(len(nodes1), 0)
self.assertEqual(len(nodes2), 0)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment