Skip to content
Snippets Groups Projects
Commit 40c1cfed authored by Martin Larralde's avatar Martin Larralde
Browse files

Update `TopHits` to reference the query it was created with

parent e8f254c8
No related branches found
No related tags found
No related merge requests found
......@@ -408,9 +408,11 @@ cdef class TopHits:
# computed and thresholding can be done correctly.
cdef P7_PIPELINE _pli
cdef P7_TOPHITS* _th
cdef bytes _qname
cdef bytes _qacc
cdef int _qlen
cdef object _query
# cdef bytes _qname
# cdef bytes _qacc
# cdef int _qlen
cdef int _threshold(self, Pipeline pipeline) except 1 nogil
cdef int _sort_by_key(self) except 1 nogil
......
......@@ -5556,7 +5556,6 @@ cdef class Pipeline:
if self.profile._gm == NULL:
raise AllocationError("P7_PROFILE", sizeof(P7_OPROFILE))
else:
self.profile.clear()
# configure the profile from the query HMM
self.profile.configure(<HMM> query, self.background, L)
......@@ -5816,12 +5815,7 @@ cdef class Pipeline:
hits._threshold(self)
# record the query metadata
hits._qlen = om.M
if om.name != NULL:
hits._qname = PyBytes_FromString(om.name)
if om.acc != NULL:
hits._qacc = PyBytes_FromString(om.acc)
hits._query = query
# return the hits
return hits
......@@ -5877,6 +5871,7 @@ cdef class Pipeline:
cdef HMM hmm
cdef OptimizedProfile opt
cdef Profile profile
cdef TopHits hits
# check the pipeline was configured with the same alphabet
if not self.alphabet._eq(query.alphabet):
......@@ -5886,12 +5881,15 @@ cdef class Pipeline:
# build the HMM and the profile from the query MSA
hmm, profile, opt = builder.build_msa(query, self.background)
if isinstance(sequences, DigitalSequenceBlock):
return self.search_hmm[DigitalSequenceBlock](opt, sequences)
hits = self.search_hmm[DigitalSequenceBlock](opt, sequences)
elif isinstance(sequences, SequenceFile):
return self.search_hmm[SequenceFile](opt, sequences)
hits = self.search_hmm[SequenceFile](opt, sequences)
else:
ty = type(sequences).__name__
raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
# record query metadata
hits._query = query
return hits
cpdef TopHits search_seq(
self,
......@@ -5938,6 +5936,7 @@ cdef class Pipeline:
cdef HMM hmm
cdef OptimizedProfile opt
cdef Profile profile
cdef TopHits hits
# check the pipeline was configured with the same alphabet
if not self.alphabet._eq(query.alphabet):
......@@ -5949,12 +5948,15 @@ cdef class Pipeline:
# build the HMM and the profile from the query sequence
hmm, profile, opt = builder.build(query, self.background)
if isinstance(sequences, DigitalSequenceBlock):
return self.search_hmm[DigitalSequenceBlock](opt, sequences)
hits = self.search_hmm[DigitalSequenceBlock](opt, sequences)
elif isinstance(sequences, SequenceFile):
return self.search_hmm[SequenceFile](opt, sequences)
hits = self.search_hmm[SequenceFile](opt, sequences)
else:
ty = type(sequences).__name__
raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
# record query metadata
hits._query = query
return hits
@staticmethod
cdef int _search_loop(
......@@ -6182,12 +6184,7 @@ cdef class Pipeline:
hits._threshold(self)
# record the query metadata
hits._qlen = query._sq.L
if query._sq.name != NULL:
hits._qname = PyBytes_FromString(query._sq.name)
if query._sq.acc != NULL:
hits._qacc = PyBytes_FromString(query._sq.acc)
hits._query = query
# return the hits
return hits
......@@ -6938,12 +6935,7 @@ cdef class LongTargetsPipeline(Pipeline):
hits._pli.pos_output += 1 + llabs(hits._th.hit[j].dcl[0].jali - hits._th.hit[j].dcl[0].iali)
# record the query metadata
hits._qlen = om.M
if om.name != NULL:
hits._qname = PyBytes_FromString(om.name)
if om.acc != NULL:
hits._qacc = PyBytes_FromString(om.acc)
hits._query = query
# return the hits
return hits
......@@ -6981,6 +6973,10 @@ cdef class LongTargetsPipeline(Pipeline):
"""
assert self._pli != NULL
cdef HMM hmm
cdef TopHits hits
if not self.alphabet._eq(query.alphabet):
raise AlphabetMismatch(self.alphabet, query.alphabet)
......@@ -6996,16 +6992,19 @@ cdef class LongTargetsPipeline(Pipeline):
elif builder.window_beta != self.window_beta:
raise ValueError("builder and long targets pipeline have different window beta")
cdef HMM hmm = builder.build(query, self.background)[0]
hmm = builder.build(query, self.background)[0]
assert hmm._hmm.max_length != -1
if isinstance(sequences, DigitalSequenceBlock):
return self.search_hmm[DigitalSequenceBlock](hmm, sequences)
hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences)
elif isinstance(sequences, SequenceFile):
return self.search_hmm[SequenceFile](hmm, sequences)
hits = self.search_hmm[SequenceFile](hmm, sequences)
else:
ty = type(sequences).__name__
raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
hits._query = query
return hits
cpdef TopHits search_msa(
self,
DigitalMSA query,
......@@ -7042,6 +7041,9 @@ cdef class LongTargetsPipeline(Pipeline):
"""
assert self._pli != NULL
cdef HMM hmm
cdef TopHits hits
if not self.alphabet._eq(query.alphabet):
raise AlphabetMismatch(self.alphabet, query.alphabet)
......@@ -7057,15 +7059,18 @@ cdef class LongTargetsPipeline(Pipeline):
elif builder.window_beta != self.window_beta:
raise ValueError("builder and long targets pipeline have different window beta")
cdef HMM hmm = builder.build_msa(query, self.background)[0]
hmm = builder.build_msa(query, self.background)[0]
assert hmm._hmm.max_length != -1
if isinstance(sequences, DigitalSequenceBlock):
return self.search_hmm[DigitalSequenceBlock](hmm, sequences)
hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences)
elif isinstance(sequences, SequenceFile):
return self.search_hmm[SequenceFile](hmm, sequences)
hits = self.search_hmm[SequenceFile](hmm, sequences)
else:
ty = type(sequences).__name__
raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
hits._query = query
return hits
@staticmethod
cdef int _search_loop_longtargets(
......@@ -7672,9 +7677,7 @@ cdef class TopHits:
def __cinit__(self):
self._th = NULL
self._qname = None
self._qacc = None
self._qlen = -1
self._query = None
memset(&self._pli, 0, sizeof(P7_PIPELINE))
def __init__(self):
......@@ -7746,9 +7749,7 @@ cdef class TopHits:
hits.append(offset)
return {
"qname": self._qname,
"qacc": self._qacc,
"qlen": self._qlen,
"query": self.query,
"unsrt": unsrt,
"hit": hits,
"Nalloc": self._th.Nalloc,
......@@ -7814,9 +7815,7 @@ cdef class TopHits:
cdef VectorU8 hit_state
# record query name and accession
self._qname = state["qname"]
self._qacc = state["qacc"]
self._qlen = state["qlen"]
self.query = state["query"]
# deallocate current data if needed
if self._th != NULL:
......@@ -7925,7 +7924,7 @@ cdef class TopHits:
.. versionadded:: 0.6.1
"""
return self._qname
return self._query.name
@property
def query_accession(self):
......@@ -7934,7 +7933,7 @@ cdef class TopHits:
.. versionadded:: 0.6.1
"""
return self._qacc
return self._query.accession
@property
def query_length(self):
......@@ -7943,7 +7942,7 @@ cdef class TopHits:
.. versionadded:: 0.10.5
"""
return self._qlen
return self._query.M if isinstance(self._query, HMM) else len(self._query)
@property
def Z(self):
......@@ -8546,7 +8545,7 @@ cdef class TopHits:
continue
# check that names/accessions are consistent
if merged._qname != other._qname or merged._qacc != other._qacc or merged._qlen != other._qlen:
if merged._query != other._query:
raise ValueError("Trying to merge `TopHits` obtained from different queries")
# check that the parameters are the same
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment