From 21600453d707277b5fbcd4afa4ba0b8c61832824 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Tue, 8 Oct 2024 21:05:04 +0200
Subject: [PATCH] Add `SequenceFile` as an accepted type for most `hmmer`
 functions

---
 pyhmmer/hmmer.py | 50 +++++++++++++++++++++++-------------------------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/pyhmmer/hmmer.py b/pyhmmer/hmmer.py
index fa1e1fe9..a3835616 100644
--- a/pyhmmer/hmmer.py
+++ b/pyhmmer/hmmer.py
@@ -809,7 +809,7 @@ class _SCANDispatcher(
 
 def hmmsearch(
     queries: typing.Union[_P, typing.Iterable[_P]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     cpus: int = 0,
     callback: typing.Optional[typing.Callable[[_P, int], None]] = None,
@@ -884,14 +884,14 @@ def hmmsearch(
         queries = (queries,)
 
     if isinstance(sequences, SequenceFile):
-        sequence_file: SequenceFile = sequences
-        if sequence_file.name is None:
+        # sequence_file = sequences
+        if sequences.name is None:
             raise ValueError("expected named `SequenceFile` for targets")
-        if not sequence_file.digital:
+        if not sequences.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
-        assert sequence_file.alphabet is not None
-        alphabet = alphabet or sequence_file.alphabet
-        targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequence_file
+        assert sequences.alphabet is not None
+        alphabet = alphabet or sequences.alphabet
+        targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequences
     elif isinstance(sequences, DigitalSequenceBlock):
         alphabet = alphabet or sequences.alphabet
         targets = sequences
@@ -923,7 +923,7 @@ def hmmsearch(
 
 def phmmer(
     queries: typing.Union[_M, typing.Iterable[_M]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     cpus: int = 0,
     callback: typing.Optional[typing.Callable[[_M, int], None]] = None,
@@ -1023,7 +1023,7 @@ def phmmer(
 @typing.overload
 def jackhmmer(
     queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     max_iterations: typing.Optional[int] = 5,
     select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
@@ -1039,7 +1039,7 @@ def jackhmmer(
 @typing.overload
 def jackhmmer(
     queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     max_iterations: typing.Optional[int] = 5,
     select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
@@ -1055,7 +1055,7 @@ def jackhmmer(
 @typing.overload
 def jackhmmer(
     queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     max_iterations: typing.Optional[int] = 5,
     select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
@@ -1072,7 +1072,7 @@ def jackhmmer(
 
 def jackhmmer(
     queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     max_iterations: typing.Optional[int] = 5,
     select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
@@ -1157,14 +1157,13 @@ def jackhmmer(
         queries = (queries,)
 
     if isinstance(sequences, SequenceFile):
-        sequence_file: SequenceFile = sequences
-        if sequence_file.name is None:
+        if sequences.name is None:
             raise ValueError("expected named `SequenceFile` for targets")
-        if not sequence_file.digital:
+        if not sequences.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
-        assert sequence_file.alphabet is not None
-        alphabet = alphabet or sequence_file.alphabet
-        targets = typing.cast(DigitalSequenceBlock, sequence_file.read_block())
+        assert sequences.alphabet is not None
+        alphabet = alphabet or sequences.alphabet
+        targets = typing.cast(DigitalSequenceBlock, sequences.read_block())
     elif isinstance(sequences, DigitalSequenceBlock):
         alphabet = alphabet or sequences.alphabet
         targets = sequences
@@ -1206,7 +1205,7 @@ def jackhmmer(
 
 def nhmmer(
     queries: typing.Union[_N, typing.Iterable[_N]],
-    sequences: typing.Iterable[DigitalSequence],
+    sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
     *,
     cpus: int = 0,
     callback: typing.Optional[typing.Callable[[_N, int], None]] = None,
@@ -1270,14 +1269,13 @@ def nhmmer(
         queries = (queries,)
 
     if isinstance(sequences, SequenceFile):
-        sequence_file: SequenceFile = sequences
-        if sequence_file.name is None:
+        if sequences.name is None:
             raise ValueError("expected named `SequenceFile` for targets")
-        if not sequence_file.digital:
+        if not sequences.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
-        assert sequence_file.alphabet is not None
-        alphabet = alphabet or sequence_file.alphabet
-        targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequence_file
+        assert sequences.alphabet is not None
+        alphabet = alphabet or sequences.alphabet
+        targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequences
     elif isinstance(sequences, DigitalSequenceBlock):
         alphabet = alphabet or sequences.alphabet
         targets = sequences
@@ -1594,7 +1592,7 @@ if __name__ == "__main__":
             with HMMFile(args.hmmfile) as hmms:
                 if hmms.is_pressed():
                     hmms = hmms.optimized_profiles()  # type: ignore
-                hits_list = hmmsearch(hmms, sequences, cpus=args.jobs)  # type: ignore
+                hits_list = hmmsearch(hmms, sequences, cpus=args.jobs)
                 for hits in hits_list:
                     for hit in hits:
                         if hit.reported:
-- 
GitLab