From 41c377e648babcb511ae1f6baf16c98496564f5a Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Tue, 8 Oct 2024 15:25:34 +0200
Subject: [PATCH] Fix handling of `TopHits` in `pyhmmer.daemon.Client`

---
 pyhmmer/daemon.pxd |  2 +-
 pyhmmer/daemon.pyx | 51 +++++++++++-----------------------------------
 2 files changed, 13 insertions(+), 40 deletions(-)

diff --git a/pyhmmer/daemon.pxd b/pyhmmer/daemon.pxd
index 199c6159..4aa1f2cd 100644
--- a/pyhmmer/daemon.pxd
+++ b/pyhmmer/daemon.pxd
@@ -22,7 +22,7 @@ cdef class Client:
     cdef bytearray _recvall(self, size_t message_size)
     cdef TopHits _client(
         self,
-        bytes query,
+        object query,
         uint64_t db,
         list ranges,
         Pipeline pli,
diff --git a/pyhmmer/daemon.pyx b/pyhmmer/daemon.pyx
index d832cf0c..c161e5ca 100644
--- a/pyhmmer/daemon.pyx
+++ b/pyhmmer/daemon.pyx
@@ -160,7 +160,7 @@ cdef class Client:
 
     cdef TopHits _client(
         self,
-        bytes query,
+        object query,
         uint64_t db,
         list ranges,
         Pipeline pli,
@@ -190,7 +190,7 @@ cdef class Client:
 
         cdef uint32_t           hits_start
         cdef uint32_t           buf_offset    = 0
-        cdef TopHits            hits          = TopHits()
+        cdef TopHits            hits          = TopHits(query)
         cdef str                options       = "".join(pli.arguments())
 
         # check ranges argument
@@ -207,6 +207,12 @@ cdef class Client:
         memset(&search_status, 0, sizeof(HMMD_SEARCH_STATUS))
         search_stats.hit_offsets = NULL
 
+        # serialize query
+        with io.BytesIO() as buffer:
+            query.write(buffer)
+            buffer.write(b"\n//")
+            txt = buffer.getvalue()
+
         try:
             # send the options
             if mode == p7_pipemodes_e.p7_SEARCH_SEQS:
@@ -220,7 +226,7 @@ cdef class Client:
                 self.socket.sendall(f"@--hmmdb {db} {options}\n".encode("ascii"))
 
             # send the query
-            self.socket.sendall(query)
+            self.socket.sendall(txt)
 
             # get the search status back
             response = self._recvall(HMMD_SEARCH_STATUS_SERIAL_SIZE)
@@ -344,20 +350,9 @@ cdef class Client:
             sequence against the sequence database loaded on the server side.
 
         """
-        cdef bytes    txt
-        cdef TopHits  hits
         cdef Alphabet abc  = getattr(query, "alphabet", Alphabet.amino())
         cdef Pipeline pli  = Pipeline(abc, **options)
-
-        with io.BytesIO() as buffer:
-            query.write(buffer)
-            buffer.write(b"\n//")
-            txt = buffer.getvalue()
-
-        hits = self._client(txt, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS)
-        hits._query = query
-
-        return hits
+        return self._client(query, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS)
 
     def search_hmm(
         self,
@@ -384,19 +379,8 @@ cdef class Client:
             server side.
 
         """
-        cdef bytes    txt
-        cdef TopHits  hits
         cdef Pipeline pli  = Pipeline(query.alphabet, **options)
-
-        with io.BytesIO() as buffer:
-            query.write(buffer)
-            buffer.write(b"\n//")
-            txt = buffer.getvalue()
-
-        hits = self._client(txt, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS)
-        hits._query = query
-
-        return hits
+        return self._client(query, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS)
 
     def scan_seq(self, Sequence query, uint64_t db = 1, **options):
         """Search the HMMER daemon database with a query sequence.
@@ -415,20 +399,9 @@ cdef class Client:
             server side.
 
         """
-        cdef bytes    txt
-        cdef TopHits  hits
         cdef Alphabet abc  = getattr(query, "alphabet", Alphabet.amino())
         cdef Pipeline pli  = Pipeline(abc, **options)
-
-        with io.BytesIO() as buffer:
-            query.write(buffer)
-            buffer.write(b"\n//")
-            txt = buffer.getvalue()
-
-        hits = self._client(txt, db, None, pli, p7_pipemodes_e.p7_SCAN_MODELS)
-        hits._query = query
-
-        return hits
+        return self._client(query, db, None, pli, p7_pipemodes_e.p7_SCAN_MODELS)
 
     def iterate_seq(
         self,
-- 
GitLab