From 40c1cfedfb75a53eeb7a5723163afa8a2829c8bd Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Mon, 7 Oct 2024 17:59:28 +0200 Subject: [PATCH] Update `TopHits` to reference the query it was created with --- pyhmmer/plan7.pxd | 8 +++-- pyhmmer/plan7.pyx | 83 +++++++++++++++++++++++------------------------ 2 files changed, 46 insertions(+), 45 deletions(-) diff --git a/pyhmmer/plan7.pxd b/pyhmmer/plan7.pxd index 137d2279..2b2e5977 100644 --- a/pyhmmer/plan7.pxd +++ b/pyhmmer/plan7.pxd @@ -408,9 +408,11 @@ cdef class TopHits: # computed and thresholding can be done correctly. cdef P7_PIPELINE _pli cdef P7_TOPHITS* _th - cdef bytes _qname - cdef bytes _qacc - cdef int _qlen + cdef object _query + + # cdef bytes _qname + # cdef bytes _qacc + # cdef int _qlen cdef int _threshold(self, Pipeline pipeline) except 1 nogil cdef int _sort_by_key(self) except 1 nogil diff --git a/pyhmmer/plan7.pyx b/pyhmmer/plan7.pyx index 22a1c19b..743866e4 100644 --- a/pyhmmer/plan7.pyx +++ b/pyhmmer/plan7.pyx @@ -5556,7 +5556,6 @@ cdef class Pipeline: if self.profile._gm == NULL: raise AllocationError("P7_PROFILE", sizeof(P7_OPROFILE)) else: - self.profile.clear() # configure the profile from the query HMM self.profile.configure(<HMM> query, self.background, L) @@ -5816,12 +5815,7 @@ cdef class Pipeline: hits._threshold(self) # record the query metadata - hits._qlen = om.M - if om.name != NULL: - hits._qname = PyBytes_FromString(om.name) - if om.acc != NULL: - hits._qacc = PyBytes_FromString(om.acc) - + hits._query = query # return the hits return hits @@ -5877,6 +5871,7 @@ cdef class Pipeline: cdef HMM hmm cdef OptimizedProfile opt cdef Profile profile + cdef TopHits hits # check the pipeline was configured with the same alphabet if not self.alphabet._eq(query.alphabet): @@ -5886,12 +5881,15 @@ cdef class Pipeline: # build the HMM and the profile from the query MSA hmm, profile, opt = builder.build_msa(query, self.background) if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](opt, sequences) + hits = self.search_hmm[DigitalSequenceBlock](opt, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](opt, sequences) + hits = self.search_hmm[SequenceFile](opt, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + # record query metadata + hits._query = query + return hits cpdef TopHits search_seq( self, @@ -5938,6 +5936,7 @@ cdef class Pipeline: cdef HMM hmm cdef OptimizedProfile opt cdef Profile profile + cdef TopHits hits # check the pipeline was configure with the same alphabet if not self.alphabet._eq(query.alphabet): @@ -5949,12 +5948,15 @@ cdef class Pipeline: # build the HMM and the profile from the query sequence hmm, profile, opt = builder.build(query, self.background) if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](opt, sequences) + hits = self.search_hmm[DigitalSequenceBlock](opt, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](opt, sequences) + hits = self.search_hmm[SequenceFile](opt, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + # record query metadata + hits._query = query + return hits @staticmethod cdef int _search_loop( @@ -6182,12 +6184,7 @@ cdef class Pipeline: hits._threshold(self) # record the query metadata - hits._qlen = query._sq.L - if query._sq.name != NULL: - hits._qname = PyBytes_FromString(query._sq.name) - if query._sq.acc != NULL: - hits._qacc = PyBytes_FromString(query._sq.acc) - + hits._query = query # return the hits return hits @@ -6938,12 +6935,7 @@ cdef class LongTargetsPipeline(Pipeline): hits._pli.pos_output += 1 + llabs(hits._th.hit[j].dcl[0].jali - hits._th.hit[j].dcl[0].iali) # record the query metadata - hits._qlen = om.M - if om.name != NULL: - hits._qname = PyBytes_FromString(om.name) - if om.acc != NULL: - hits._qacc = PyBytes_FromString(om.acc) - + hits._query = query # return the hits return hits @@ -6981,6 +6973,10 @@ cdef class LongTargetsPipeline(Pipeline): """ assert self._pli != NULL + + cdef HMM hmm + cdef TopHits hits + if not self.alphabet._eq(query.alphabet): raise AlphabetMismatch(self.alphabet, query.alphabet) @@ -6996,16 +6992,19 @@ cdef class LongTargetsPipeline(Pipeline): elif builder.window_beta != self.window_beta: raise ValueError("builder and long targets pipeline have different window beta") - cdef HMM hmm = builder.build(query, self.background)[0] + hmm = builder.build(query, self.background)[0] assert hmm._hmm.max_length != -1 if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](hmm, sequences) + hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](hmm, sequences) + hits = self.search_hmm[SequenceFile](hmm, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + hits._query = query + return hits + cpdef TopHits search_msa( self, DigitalMSA query, @@ -7042,6 +7041,9 @@ cdef class LongTargetsPipeline(Pipeline): """ assert self._pli != NULL + cdef HMM hmm + cdef TopHits hits + if not self.alphabet._eq(query.alphabet): raise AlphabetMismatch(self.alphabet, query.alphabet) @@ -7057,15 +7059,18 @@ cdef class LongTargetsPipeline(Pipeline): elif builder.window_beta != self.window_beta: raise ValueError("builder and long targets pipeline have different window beta") - cdef HMM hmm = builder.build_msa(query, self.background)[0] + hmm = builder.build_msa(query, self.background)[0] assert hmm._hmm.max_length != -1 if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](hmm, sequences) + hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](hmm, sequences) + hits = self.search_hmm[SequenceFile](hmm, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + + hits._query = query + return hits @staticmethod cdef int _search_loop_longtargets( @@ -7672,9 +7677,7 @@ cdef class TopHits: def __cinit__(self): self._th = NULL - self._qname = None - self._qacc = None - self._qlen = -1 + self._query = None memset(&self._pli, 0, sizeof(P7_PIPELINE)) def __init__(self): @@ -7746,9 +7749,7 @@ cdef class TopHits: hits.append(offset) return { - "qname": self._qname, - "qacc": self._qacc, - "qlen": self._qlen, + "query": self.query, "unsrt": unsrt, "hit": hits, "Nalloc": self._th.Nalloc, @@ -7814,9 +7815,7 @@ cdef class TopHits: cdef VectorU8 hit_state # record query name and accession - self._qname = state["qname"] - self._qacc = state["qacc"] - self._qlen = state["qlen"] + self.query = state["query"] # deallocate current data if needed if self._th != NULL: @@ -7925,7 +7924,7 @@ cdef class TopHits: .. versionadded:: 0.6.1 """ - return self._qname + return self._query.name @property def query_accession(self): @@ -7934,7 +7933,7 @@ cdef class TopHits: .. versionadded:: 0.6.1 """ - return self._qacc + return self._query.accession @property def query_length(self): @@ -7943,7 +7942,7 @@ cdef class TopHits: .. versionadded:: 0.10.5 """ - return self._qlen + return self._query.M if isinstance(self._query, HMM) else len(self._query) @property def Z(self): @@ -8546,7 +8545,7 @@ cdef class TopHits: continue # check that names/accessions are consistent - if merged._qname != other._qname or merged._qacc != other._qacc or merged._qlen != other._qlen: + if merged._query != other._query: raise ValueError("Trying to merge `TopHits` obtained from different queries") # check that the parameters are the same -- GitLab