From 40c1cfedfb75a53eeb7a5723163afa8a2829c8bd Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Mon, 7 Oct 2024 17:59:28 +0200
Subject: [PATCH] Update `TopHits` to reference the query it was created with

---
 pyhmmer/plan7.pxd |  8 +++--
 pyhmmer/plan7.pyx | 83 +++++++++++++++++++++++------------------------
 2 files changed, 46 insertions(+), 45 deletions(-)

diff --git a/pyhmmer/plan7.pxd b/pyhmmer/plan7.pxd
index 137d2279..2b2e5977 100644
--- a/pyhmmer/plan7.pxd
+++ b/pyhmmer/plan7.pxd
@@ -408,9 +408,11 @@ cdef class TopHits:
     #                  computed and thresholding can be done correctly.
     cdef P7_PIPELINE _pli
     cdef P7_TOPHITS* _th
-    cdef bytes       _qname
-    cdef bytes       _qacc
-    cdef int         _qlen
+    cdef object      _query
+
+    # cdef bytes       _qname
+    # cdef bytes       _qacc
+    # cdef int         _qlen
 
     cdef int _threshold(self, Pipeline pipeline) except 1 nogil
     cdef int _sort_by_key(self) except 1 nogil
diff --git a/pyhmmer/plan7.pyx b/pyhmmer/plan7.pyx
index 22a1c19b..743866e4 100644
--- a/pyhmmer/plan7.pyx
+++ b/pyhmmer/plan7.pyx
@@ -5556,7 +5556,6 @@ cdef class Pipeline:
                 if self.profile._gm == NULL:
                     raise AllocationError("P7_PROFILE", sizeof(P7_OPROFILE))
             else:
-
                 self.profile.clear()
             # configure the profile from the query HMM
             self.profile.configure(<HMM> query, self.background, L)
@@ -5816,12 +5815,7 @@ cdef class Pipeline:
             hits._threshold(self)
 
         # record the query metadata
-        hits._qlen = om.M
-        if om.name != NULL:
-            hits._qname = PyBytes_FromString(om.name)
-        if om.acc != NULL:
-            hits._qacc = PyBytes_FromString(om.acc)
-
+        hits._query = query
         # return the hits
         return hits
 
@@ -5877,6 +5871,7 @@ cdef class Pipeline:
         cdef HMM              hmm
         cdef OptimizedProfile opt
         cdef Profile          profile
+        cdef TopHits          hits
 
         # check the pipeline was configured with the same alphabet
         if not self.alphabet._eq(query.alphabet):
@@ -5886,12 +5881,15 @@ cdef class Pipeline:
         # build the HMM and the profile from the query MSA
         hmm, profile, opt = builder.build_msa(query, self.background)
         if isinstance(sequences, DigitalSequenceBlock):
-            return self.search_hmm[DigitalSequenceBlock](opt, sequences)
+            hits = self.search_hmm[DigitalSequenceBlock](opt, sequences)
         elif isinstance(sequences, SequenceFile):
-            return self.search_hmm[SequenceFile](opt, sequences)
+            hits = self.search_hmm[SequenceFile](opt, sequences)
         else:
             ty = type(sequences).__name__
             raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
+        # record query metadata
+        hits._query = query
+        return hits        
 
     cpdef TopHits search_seq(
         self,
@@ -5938,6 +5936,7 @@ cdef class Pipeline:
         cdef HMM              hmm
         cdef OptimizedProfile opt
         cdef Profile          profile
+        cdef TopHits          hits
 
         # check the pipeline was configure with the same alphabet
         if not self.alphabet._eq(query.alphabet):
@@ -5949,12 +5948,15 @@ cdef class Pipeline:
         # build the HMM and the profile from the query sequence
         hmm, profile, opt = builder.build(query, self.background)
         if isinstance(sequences, DigitalSequenceBlock):
-            return self.search_hmm[DigitalSequenceBlock](opt, sequences)
+            hits = self.search_hmm[DigitalSequenceBlock](opt, sequences)
         elif isinstance(sequences, SequenceFile):
-            return self.search_hmm[SequenceFile](opt, sequences)
+            hits = self.search_hmm[SequenceFile](opt, sequences)
         else:
             ty = type(sequences).__name__
             raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
+        # record query metadata
+        hits._query = query
+        return hits
 
     @staticmethod
     cdef int _search_loop(
@@ -6182,12 +6184,7 @@ cdef class Pipeline:
             hits._threshold(self)
 
         # record the query metadata
-        hits._qlen = query._sq.L
-        if query._sq.name != NULL:
-            hits._qname = PyBytes_FromString(query._sq.name)
-        if query._sq.acc != NULL:
-            hits._qacc = PyBytes_FromString(query._sq.acc)
-
+        hits._query = query
         # return the hits
         return hits
 
@@ -6938,12 +6935,7 @@ cdef class LongTargetsPipeline(Pipeline):
                     hits._pli.pos_output += 1 + llabs(hits._th.hit[j].dcl[0].jali - hits._th.hit[j].dcl[0].iali)
 
         # record the query metadata
-        hits._qlen = om.M
-        if om.name != NULL:
-            hits._qname = PyBytes_FromString(om.name)
-        if om.acc != NULL:
-            hits._qacc = PyBytes_FromString(om.acc)
-
+        hits._query = query
         # return the hits
         return hits
 
@@ -6981,6 +6973,10 @@ cdef class LongTargetsPipeline(Pipeline):
 
         """
         assert self._pli != NULL
+
+        cdef HMM              hmm
+        cdef TopHits          hits
+
         if not self.alphabet._eq(query.alphabet):
             raise AlphabetMismatch(self.alphabet, query.alphabet)
 
@@ -6996,16 +6992,19 @@ cdef class LongTargetsPipeline(Pipeline):
         elif builder.window_beta != self.window_beta:
             raise ValueError("builder and long targets pipeline have different window beta")
 
-        cdef HMM hmm = builder.build(query, self.background)[0]
+        hmm = builder.build(query, self.background)[0]
         assert hmm._hmm.max_length != -1
         if isinstance(sequences, DigitalSequenceBlock):
-            return self.search_hmm[DigitalSequenceBlock](hmm, sequences)
+            hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences)
         elif isinstance(sequences, SequenceFile):
-            return self.search_hmm[SequenceFile](hmm, sequences)
+            hits = self.search_hmm[SequenceFile](hmm, sequences)
         else:
             ty = type(sequences).__name__
             raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
 
+        hits._query = query
+        return hits
+
     cpdef TopHits search_msa(
         self,
         DigitalMSA query,
@@ -7042,6 +7041,9 @@ cdef class LongTargetsPipeline(Pipeline):
         """
         assert self._pli != NULL
 
+        cdef HMM              hmm
+        cdef TopHits          hits
+
         if not self.alphabet._eq(query.alphabet):
             raise AlphabetMismatch(self.alphabet, query.alphabet)
 
@@ -7057,15 +7059,18 @@ cdef class LongTargetsPipeline(Pipeline):
         elif builder.window_beta != self.window_beta:
             raise ValueError("builder and long targets pipeline have different window beta")
 
-        cdef HMM hmm = builder.build_msa(query, self.background)[0]
+        hmm = builder.build_msa(query, self.background)[0]
         assert hmm._hmm.max_length != -1
         if isinstance(sequences, DigitalSequenceBlock):
-            return self.search_hmm[DigitalSequenceBlock](hmm, sequences)
+            hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences)
         elif isinstance(sequences, SequenceFile):
-            return self.search_hmm[SequenceFile](hmm, sequences)
+            hits = self.search_hmm[SequenceFile](hmm, sequences)
         else:
             ty = type(sequences).__name__
             raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}")
+        
+        hits._query = query
+        return hits
 
     @staticmethod
     cdef int _search_loop_longtargets(
@@ -7672,9 +7677,7 @@ cdef class TopHits:
 
     def __cinit__(self):
         self._th = NULL
-        self._qname = None
-        self._qacc = None
-        self._qlen = -1
+        self._query = None
         memset(&self._pli, 0, sizeof(P7_PIPELINE))
 
     def __init__(self):
@@ -7746,9 +7749,7 @@ cdef class TopHits:
             hits.append(offset)
 
         return {
-            "qname": self._qname,
-            "qacc": self._qacc,
-            "qlen": self._qlen,
+            "query": self.query,
             "unsrt": unsrt,
             "hit": hits,
             "Nalloc": self._th.Nalloc,
@@ -7814,9 +7815,7 @@ cdef class TopHits:
         cdef VectorU8 hit_state
 
         # record query name and accession
-        self._qname = state["qname"]
-        self._qacc = state["qacc"]
-        self._qlen = state["qlen"]
+        self.query = state["query"]
 
         # deallocate current data if needed
         if self._th != NULL:
@@ -7925,7 +7924,7 @@ cdef class TopHits:
         .. versionadded:: 0.6.1
 
         """
-        return self._qname
+        return self._query.name
 
     @property
     def query_accession(self):
@@ -7934,7 +7933,7 @@ cdef class TopHits:
         .. versionadded:: 0.6.1
 
         """
-        return self._qacc
+        return self._query.accession
 
     @property
     def query_length(self):
@@ -7943,7 +7942,7 @@ cdef class TopHits:
         .. versionadded:: 0.10.5
 
         """
-        return self._qlen
+        return self._query.M if isinstance(self._query, HMM) else len(self._query)
 
     @property
     def Z(self):
@@ -8546,7 +8545,7 @@ cdef class TopHits:
                 continue
 
             # check that names/accessions are consistent
-            if merged._qname != other._qname or merged._qacc != other._qacc or merged._qlen != other._qlen:
+            if merged._query != other._query:
                 raise ValueError("Trying to merge `TopHits` obtained from different queries")
 
             # check that the parameters are the same
-- 
GitLab