Commit 382787a9 authored by Martin Larralde's avatar Martin Larralde
Browse files

Implement pickle protocol for `pyrodigal.TrainingInfo`

parent 6013273e
......@@ -291,6 +291,9 @@ cdef class TrainingInfo:
cdef bint owned
cdef _training* tinf
cpdef dict __getstate__(self)
cpdef object __setstate__(self, dict state)
cdef void _update_motif_counts(double mcnt[4][4][4096], double *zero, Sequence seq, _node* nod, int stage) nogil
......@@ -190,6 +190,8 @@ class TrainingInfo:
def load(cls, fp: typing.BinaryIO) -> TrainingInfo: ...
def __init__(self, gc: float, start_weight: float = 4.35, translation_table: int = 11) -> None: ...
def __repr__(self) -> str: ...
def __getstate__(self) -> Dict[str, object]: ...
def __setstate__(self, state: Dict[str, object]) -> None: ...
def translation_table(self) -> int: ...
......@@ -1463,7 +1463,7 @@ cdef class Nodes:
cpdef list __getstate__(self):
cdef int i
cdef size_t i
return [
"type": self.nodes[i].type,
......@@ -1511,7 +1511,7 @@ cdef class Nodes:
cpdef object __setstate__(self, list state):
"""__setstate__(self, state)\n--
cdef int i
cdef size_t i
cdef dict node
cdef dict motif
cdef size_t old_capacity = self.capacity
......@@ -3409,7 +3409,7 @@ cdef class TrainingInfo:
# --- Magic methods ----------------------------------------------------
def __cinit__(self):
self.owned = False
self.owned = True
self.tinf = NULL
def __init__(self, double gc, double start_weight=4.35, int translation_table=11):
......@@ -3442,6 +3442,64 @@ cdef class TrainingInfo:
cpdef dict __getstate__(self):
assert self.tinf != NULL
cdef int i
cdef int j
cdef int k
return {
"gc": self.tinf.gc,
"translation_table": self.tinf.trans_table,
"start_weight": self.tinf.st_wt,
"bias": self.bias,
"type_weights": self.type_weights,
"uses_sd": self.uses_sd,
"rbs_wt": [
self.tinf.rbs_wt[i] for i in range(28)
"ups_comp": [
[self.tinf.ups_comp[i][j] for j in range(4)]
for i in range(32)
"mot_wt": [
[self.tinf.mot_wt[i][j][k] for k in range(4096)]
for j in range(4)
for i in range(4)
"no_mot": self.tinf.no_mot,
"gene_dc": [
self.tinf.gene_dc[i] for i in range(4096)
cpdef object __setstate__(self, dict state):
"""__setstate__(self, state)\n--
cdef int i
cdef int j
cdef int k
# allocate memory if possible / needed
if not self.owned:
raise RuntimeError("Cannot call `__setstate__` on a shared `TrainingInfo` instance")
if self.tinf == NULL:
# copy data
self.tinf.gc = state["gc"]
self.tinf.trans_table = state["translation_table"]
self.tinf.st_wt = state["start_weight"]
self.tinf.bias = state["bias"]
self.tinf.type_wt = state["type_weights"]
self.tinf.uses_sd = state["uses_sd"]
self.tinf.no_mot = state["no_mot"]
self.tinf.rbs_wt = state["rbs_wt"]
self.tinf.ups_comp = state["ups_comp"]
self.tinf.mot_wt = state["mot_wt"]
self.tinf.gene_dc = state["gene_dc"]
# --- Properties -------------------------------------------------------
import abc
import gzip
import os
import pickle
import textwrap
import unittest
import warnings
......@@ -331,3 +332,35 @@ class TestSingle(_OrfFinderTestCase, unittest.TestCase):
genes = p.find_genes("")
self.assertEqual(len(genes), 0)
self.assertRaises(StopIteration, next, iter(genes))
def test_training_info_pickle(self):
record = load_record("SRR492066")
# train separately
p1 = OrfFinder(meta=False)
with warnings.catch_warnings():
# pickle/unpickle the OrfFinder
ti = pickle.loads(pickle.dumps(p1.training_info))
p2 = OrfFinder(meta=False, training_info=ti)
# make sure the same genes are found
g1 = p1.find_genes(record.seq)
g2 = p2.find_genes(record.seq)
# make sure genes are the same
self.assertEqual(len(g1), len(g2))
for gene1, gene2 in zip(g1, g2):
self.assertEqual(gene1.begin, gene2.begin)
self.assertEqual(gene1.end, gene2.end)
self.assertEqual(gene1.strand, gene2.strand)
self.assertEqual(gene1.partial_begin, gene2.partial_begin)
self.assertEqual(gene1.partial_end, gene2.partial_end)
self.assertEqual(gene1.start_type, gene2.start_type)
self.assertEqual(gene1.rbs_spacer, gene2.rbs_spacer)
self.assertEqual(gene1.gc_cont, gene2.gc_cont)
self.assertEqual(gene1.translation_table, gene2.translation_table)
self.assertEqual(gene1.cscore, gene2.cscore)
self.assertEqual(gene1.rscore, gene2.rscore)
self.assertEqual(gene1.sscore, gene2.sscore)
self.assertEqual(gene1.tscore, gene2.tscore)
self.assertEqual(gene1.uscore, gene2.uscore)
self.assertEqual(gene1.score, gene2.score)
......@@ -4,15 +4,25 @@ import os
import tempfile
import textwrap
import unittest
import pickle
import warnings
from .. import OrfFinder, TrainingInfo
from .._pyrodigal import METAGENOMIC_BINS
from .fasta import parse
from .utils import load_record
class TestTrainingInfo(unittest.TestCase):
def assertTrainingInfoEqual(self, t1, t2):
self.assertEqual(t1.translation_table, t2.translation_table)
self.assertEqual(t1.gc, t2.gc)
self.assertEqual(t1.bias, t2.bias)
self.assertEqual(t1.type_weights, t2.type_weights)
self.assertEqual(t1.uses_sd, t2.uses_sd)
self.assertEqual(t1.start_weight, t2.start_weight)
def setUpClass(cls):
with warnings.catch_warnings():
......@@ -30,13 +40,7 @@ class TestTrainingInfo(unittest.TestCase):
if os.path.exists(filename):
self.assertEqual(tinf.translation_table, self.training_info.translation_table)
self.assertEqual(tinf.gc, self.training_info.gc)
self.assertEqual(tinf.bias, self.training_info.bias)
self.assertEqual(tinf.type_weights, self.training_info.type_weights)
self.assertEqual(tinf.uses_sd, self.training_info.uses_sd)
self.assertEqual(tinf.start_weight, self.training_info.start_weight)
self.assertTrainingInfoEqual(tinf, self.training_info)
def test_load_error(self):
......@@ -48,3 +52,8 @@ class TestTrainingInfo(unittest.TestCase):
if os.path.exists(filename):
def test_pickle(self):
t1 = METAGENOMIC_BINS[0].training_info
t2 = pickle.loads(pickle.dumps(t1))
self.assertTrainingInfoEqual(t1, t2)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment