Commit 069d959a authored by Martin Larralde
Browse files

Implement pickle protocol for `pyrodigal.OrfFinder`

parent 382787a9
......@@ -327,6 +327,9 @@ cdef class OrfFinder:
cdef readonly int max_overlap
cdef readonly TrainingInfo training_info
cpdef dict __getstate__(self)
cpdef object __setstate__(self, dict state)
cdef int _train(
self,
Sequence sequence,
......
......@@ -247,6 +247,8 @@ class OrfFinder:
max_overlap: int = 60,
) -> None: ...
def __repr__(self) -> str: ...
def __getstate__(self) -> Dict[str, object]: ...
def __setstate__(self, state: Dict[str, object]) -> None: ...
@property
def training_info(self) -> Optional[TrainingInfo]: ...
@property
......
......@@ -4387,6 +4387,29 @@ cdef class OrfFinder:
ty = type(self)
return "{}.{}({})".format(ty.__module__, ty.__name__, ", ".join(template))
cpdef dict __getstate__(self):
    """Return the state of this object as a `dict`, for the pickle protocol.

    Each entry maps an attribute name to the current value of that
    attribute on ``self``. The keys are the ones that `__setstate__`
    reads back.
    """
    cdef dict state = {}
    state["_num_seq"] = self._num_seq
    state["closed"] = self.closed
    state["meta"] = self.meta
    state["mask"] = self.mask
    state["min_gene"] = self.min_gene
    state["min_edge_gene"] = self.min_edge_gene
    state["max_overlap"] = self.max_overlap
    state["training_info"] = self.training_info
    return state
cpdef object __setstate__(self, dict state):
# Restore this object from `state`, a dict using the same keys as the one
# built by `__getstate__` (pickle protocol). A missing key raises `KeyError`.
# The lock is not read from `state`: a new `threading.Lock` is created here.
self.lock = threading.Lock()
# Every other attribute is copied from the entry of `state` bearing its name.
self._num_seq = state["_num_seq"]
self.closed = state["closed"]
self.meta = state["meta"]
self.mask = state["mask"]
self.min_gene = state["min_gene"]
self.min_edge_gene = state["min_edge_gene"]
self.max_overlap = state["max_overlap"]
self.training_info = state["training_info"]
# --- C interface --------------------------------------------------------
cdef int _train(
......
......@@ -12,6 +12,23 @@ from .utils import load_record, load_proteins, load_genes
class _OrfFinderTestCase(object):
def assertGeneEqual(self, gene1, gene2):
    """Assert that ``gene1`` and ``gene2`` agree on every compared attribute.

    The attributes are checked one at a time with ``assertEqual``, in the
    order listed below, so the first mismatching attribute is the one
    that makes the test fail.
    """
    compared_attributes = (
        "begin",
        "end",
        "strand",
        "partial_begin",
        "partial_end",
        "start_type",
        "rbs_spacer",
        "gc_cont",
        "translation_table",
        "cscore",
        "rscore",
        "sscore",
        "tscore",
        "uscore",
        "score",
    )
    for attribute in compared_attributes:
        self.assertEqual(getattr(gene1, attribute), getattr(gene2, attribute))
def assertTranslationsEqual(self, predictions, proteins):
self.assertEqual(len(predictions), len(proteins))
for pred, protein in zip(predictions, proteins):
......@@ -333,34 +350,37 @@ class TestSingle(_OrfFinderTestCase, unittest.TestCase):
self.assertEqual(len(genes), 0)
self.assertRaises(StopIteration, next, iter(genes))
def test_training_info_pickle(self):
def test_pickle(self):
# Round-trip a trained `OrfFinder` through pickle and check that the copy
# finds the same genes as the original on the full record sequence.
record = load_record("SRR492066")
# train separately
# NOTE(review): `p1` is assigned twice in a row; this looks like the old and
# new sides of the diff shown together. Confirm only the `min_gene=60`
# line is current.
p1 = OrfFinder(meta=False)
p1 = OrfFinder(meta=False, min_gene=60)
with warnings.catch_warnings():
# warnings emitted while training are ignored
warnings.simplefilter("ignore")
# training uses only the first 20000 positions of the sequence
p1.train(str(record.seq[:20000]))
# pickle/unpickle the OrfFinder
p2 = pickle.loads(pickle.dumps(p1))
# make sure the same genes are found
g1 = p1.find_genes(record.seq)
g2 = p2.find_genes(record.seq)
# make sure genes are the same
self.assertEqual(len(g1), len(g2))
# genes are compared pairwise, in the order both finders returned them
for gene1, gene2 in zip(g1, g2):
self.assertGeneEqual(gene1, gene2)
def test_training_info_pickle(self):
# Round-trip only the `training_info` of a trained `OrfFinder` through
# pickle, build a second finder from it, and check that both finders
# find the same genes on the full record sequence.
record = load_record("SRR492066")
# train separately
p1 = OrfFinder(meta=False, min_gene=60)
with warnings.catch_warnings():
# warnings emitted while training are ignored
warnings.simplefilter("ignore")
# training uses only the first 20000 positions of the sequence
p1.train(str(record.seq[:20000]))
# pickle/unpickle the TrainingInfo
ti = pickle.loads(pickle.dumps(p1.training_info))
# NOTE(review): `p2` is assigned twice in a row; this looks like the old and
# new sides of the diff shown together. Confirm only the `min_gene=60`
# line is current.
p2 = OrfFinder(meta=False, training_info=ti)
p2 = OrfFinder(meta=False, training_info=ti, min_gene=60)
# make sure the same genes are found
g1 = p1.find_genes(record.seq)
g2 = p2.find_genes(record.seq)
# make sure genes are the same
self.assertEqual(len(g1), len(g2))
for gene1, gene2 in zip(g1, g2):
# NOTE(review): the per-attribute assertions below cover the same attributes
# as `assertGeneEqual`, which is also called at the end. Presumably they are
# the removed side of the diff; confirm before keeping both.
self.assertEqual(gene1.begin, gene2.begin)
self.assertEqual(gene1.end, gene2.end)
self.assertEqual(gene1.strand, gene2.strand)
self.assertEqual(gene1.partial_begin, gene2.partial_begin)
self.assertEqual(gene1.partial_end, gene2.partial_end)
self.assertEqual(gene1.start_type, gene2.start_type)
self.assertEqual(gene1.rbs_spacer, gene2.rbs_spacer)
self.assertEqual(gene1.gc_cont, gene2.gc_cont)
self.assertEqual(gene1.translation_table, gene2.translation_table)
self.assertEqual(gene1.cscore, gene2.cscore)
self.assertEqual(gene1.rscore, gene2.rscore)
self.assertEqual(gene1.sscore, gene2.sscore)
self.assertEqual(gene1.tscore, gene2.tscore)
self.assertEqual(gene1.uscore, gene2.uscore)
self.assertEqual(gene1.score, gene2.score)
self.assertGeneEqual(gene1, gene2)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment