Commit 7d4df801 authored by Martin Larralde

Implement pickle protocol for `plan7.TopHits` using HMMER serialization

parent e5cc1874
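With this change, a `plan7.TopHits` object can be passed through `pickle` directly. A minimal usage sketch, assuming `hits` is an existing `TopHits` instance (for example the return value of `Pipeline.search_hmm`; the variable names here are purely illustrative):

```python
import pickle

# `hits` stands in for any existing plan7.TopHits instance,
# e.g. obtained from Pipeline.search_hmm (illustrative placeholder).
data = pickle.dumps(hits)           # hits are serialized with the HMMER binary format
restored = pickle.loads(data)       # rebuilt as an equivalent TopHits object
assert len(restored) == len(hits)   # same number of hits after the round-trip
```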
@@ -307,6 +307,7 @@ cdef class TopHits:
    cdef P7_TOPHITS* _th

    cpdef dict __getstate__(self)
    cpdef object __setstate__(self, dict state)

    cdef int _threshold(self, Pipeline pipeline) nogil except 1
    cdef int _sort_by_key(self) nogil except 1
......
@@ -715,6 +715,7 @@ class TopHits(typing.Sequence[Hit]):
    def __getitem__(self, index: slice) -> typing.Sequence[Hit]: ...
    def __iadd__(self, other: TopHits) -> TopHits: ...
    def __getstate__(self) -> typing.Dict[str, object]: ...
    def __setstate__(self, state: typing.Dict[str, object]) -> None: ...
    @property
    def Z(self) -> float: ...
    @property
......
@@ -6328,9 +6328,26 @@ cdef class TopHits:
    cpdef dict __getstate__(self):
        assert self._th != NULL

        cdef size_t i
        cdef ptrdiff_t offset
        cdef Hit hit = Hit.__new__(Hit, self, 0)
        cdef list unsrt = []
        cdef list hits = []

        hit.hits = self
        for i in range(self._th.N):
            hit._hit = &self._th.unsrt[i]
            state = hit.__getstate__()
            unsrt.append(state)

        for i in range(self._th.N):
            offset = (<ptrdiff_t> self._th.hit[i] - <ptrdiff_t> &self._th.unsrt[0]) / sizeof(P7_HIT)
            hits.append(offset)

        return {
            "unsrt": unsrt,
            "hit": hits,
            "Nalloc": self._th.Nalloc,
            "N": self._th.N,
            "nreported": self._th.nreported,
@@ -6386,6 +6403,102 @@
            }
        }
    cpdef object __setstate__(self, dict state):
        cdef int status
        cdef size_t i
        cdef uint32_t n
        cdef size_t offset
        cdef VectorU8 hit_state

        # deallocate current data if needed
        if self._th != NULL:
            libhmmer.p7_tophits.p7_tophits_Destroy(self._th)
            self._th = NULL

        # allocate a new `P7_TOPHITS` but using exact-sized buffers from the state
        self._th = <P7_TOPHITS*> malloc(sizeof(P7_TOPHITS))
        if self._th == NULL:
            raise AllocationError("P7_TOPHITS", sizeof(P7_TOPHITS))

        # copy numbers
        self._th.N = self._th.Nalloc = state["N"]
        self._th.nreported = state["nreported"]
        self._th.nincluded = state["nincluded"]
        self._th.is_sorted_by_seqidx = state["is_sorted_by_seqidx"]
        self._th.is_sorted_by_sortkey = state["is_sorted_by_sortkey"]

        # allocate memory for hits
        self._th.unsrt = <P7_HIT*> calloc(state["N"], sizeof(P7_HIT))
        if self._th.unsrt == NULL:
            raise AllocationError("P7_HIT", sizeof(P7_HIT), state["N"])
        self._th.hit = <P7_HIT**> calloc(state["N"], sizeof(P7_HIT*))
        if self._th.hit == NULL:
            raise AllocationError("P7_HIT*", sizeof(P7_HIT*), state["N"])

        # setup sorted array
        assert len(state["hit"]) == self._th.N
        for i, offset in enumerate(state["hit"]):
            self._th.hit[i] = &self._th.unsrt[offset]

        # deserialize hits
        assert len(state["unsrt"]) == self._th.N
        for i, hit_state in enumerate(state["unsrt"]):
            n = 0
            status = libhmmer.p7_hit.p7_hit_Deserialize(
                <const uint8_t*> hit_state._data,
                &n,
                &self._th.unsrt[i],
            )
            if status != libeasel.eslOK:
                raise UnexpectedError(status, "p7_hit_Deserialize")

        # copy pipeline configuration
        self._pli.by_E = state["pipeline"]["by_E"]
        self._pli.E = state["pipeline"]["E"]
        self._pli.T = state["pipeline"]["T"]
        self._pli.dom_by_E = state["pipeline"]["dom_by_E"]
        self._pli.domE = state["pipeline"]["domE"]
        self._pli.domT = state["pipeline"]["domT"]
        self._pli.use_bit_cutoffs = state["pipeline"]["use_bit_cutoffs"]
        self._pli.inc_by_E = state["pipeline"]["inc_by_E"]
        self._pli.incE = state["pipeline"]["incE"]
        self._pli.incT = state["pipeline"]["incT"]
        self._pli.incdom_by_E = state["pipeline"]["incdom_by_E"]
        self._pli.incdomE = state["pipeline"]["incdomE"]
        self._pli.incdomT = state["pipeline"]["incdomT"]
        self._pli.Z = state["pipeline"]["Z"]
        self._pli.domZ = state["pipeline"]["domZ"]
        self._pli.Z_setby = state["pipeline"]["Z_setby"]
        self._pli.domZ_setby = state["pipeline"]["domZ_setby"]
        self._pli.do_max = state["pipeline"]["do_max"]
        self._pli.F1 = state["pipeline"]["F1"]
        self._pli.F2 = state["pipeline"]["F2"]
        self._pli.F3 = state["pipeline"]["F3"]
        self._pli.B1 = state["pipeline"]["B1"]
        self._pli.B2 = state["pipeline"]["B2"]
        self._pli.B3 = state["pipeline"]["B3"]
        self._pli.do_biasfilter = state["pipeline"]["do_biasfilter"]
        self._pli.do_null2 = state["pipeline"]["do_null2"]
        self._pli.nmodels = state["pipeline"]["nmodels"]
        self._pli.nseqs = state["pipeline"]["nseqs"]
        self._pli.nres = state["pipeline"]["nres"]
        self._pli.nnodes = state["pipeline"]["nnodes"]
        self._pli.n_past_msv = state["pipeline"]["n_past_msv"]
        self._pli.n_past_bias = state["pipeline"]["n_past_bias"]
        self._pli.n_past_vit = state["pipeline"]["n_past_vit"]
        self._pli.n_past_fwd = state["pipeline"]["n_past_fwd"]
        self._pli.n_output = state["pipeline"]["n_output"]
        self._pli.pos_past_msv = state["pipeline"]["pos_past_msv"]
        self._pli.pos_past_bias = state["pipeline"]["pos_past_bias"]
        self._pli.pos_past_vit = state["pipeline"]["pos_past_vit"]
        self._pli.pos_past_fwd = state["pipeline"]["pos_past_fwd"]
        self._pli.pos_output = state["pipeline"]["pos_output"]
        self._pli.mode = state["pipeline"]["mode"]
        self._pli.long_targets = state["pipeline"]["long_targets"]
        self._pli.strands = state["pipeline"]["strands"]
        self._pli.W = state["pipeline"]["W"]
        self._pli.block_length = state["pipeline"]["block_length"]
    # --- Properties ---------------------------------------------------------

    @property
......
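For reference, the state dictionary built by `__getstate__` and consumed by `__setstate__` above is roughly shaped as follows; this is a sketch reconstructed from the keys appearing in the diff, with illustrative placeholder values:

```python
state = {
    "unsrt": [...],    # one HMMER-serialized buffer (VectorU8) per hit, in storage order
    "hit": [...],      # sorted order, stored as integer offsets into `unsrt`
    "Nalloc": ...,     # allocation size (set equal to N when unpickling)
    "N": ...,          # number of hits
    "nreported": ...,
    "nincluded": ...,
    "is_sorted_by_seqidx": ...,
    "is_sorted_by_sortkey": ...,
    "pipeline": {"by_E": ..., "E": ..., "T": ..., "Z": ..., "domZ": ...},  # plus the remaining P7_PIPELINE fields
}
```

Storing the sorted view as offsets rather than raw pointers is what makes the reload possible: `__setstate__` reallocates the flat `unsrt` array and then rebuilds each entry of the sorted array as `&unsrt[offset]`.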
import io
import itertools
import os
import pickle
import shutil
import unittest
import tempfile
@@ -242,3 +243,7 @@ class TestTopHits(unittest.TestCase):
            all_consensus_cols=True
        )
        self.assertIsInstance(msa_d, DigitalMSA)

    def test_pickle(self):
        pickled = pickle.loads(pickle.dumps(self.hits))
        self.assertHitsEqual(pickled, self.hits)
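The new test relies on an `assertHitsEqual` helper defined elsewhere in the test class; it is not part of this diff, but a comparison along the following lines would cover the round-trip check (the attribute names are real `Hit` properties, while the exact set of assertions is an assumption):

```python
def assert_hits_equal(case, expected, actual):
    # Hypothetical stand-in for the assertHitsEqual helper used above.
    case.assertEqual(len(expected), len(actual))
    for a, b in zip(expected, actual):
        case.assertEqual(a.name, b.name)              # target sequence name
        case.assertAlmostEqual(a.score, b.score)      # bit score
        case.assertAlmostEqual(a.evalue, b.evalue)    # E-value
```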