From 950e723804eecd5d3d52975251f27a5c7bf0d8dd Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Tue, 8 Oct 2024 20:48:49 +0200
Subject: [PATCH] Reorganize detection and handling of alphabets in
 `pyhmmer.hmmer`

---
 pyhmmer/hmmer.py | 189 ++++++++++++++++++++++++++---------------------
 1 file changed, 103 insertions(+), 86 deletions(-)

diff --git a/pyhmmer/hmmer.py b/pyhmmer/hmmer.py
index bc5e4b47..05c511c0 100644
--- a/pyhmmer/hmmer.py
+++ b/pyhmmer/hmmer.py
@@ -85,9 +85,10 @@ if typing.TYPE_CHECKING:
     except ImportError:
         from typing_extensions import Literal, TypedDict, Unpack  # type: ignore
 
-    from .plan7 import BIT_CUTOFFS
+    from .plan7 import BIT_CUTOFFS, STRAND
 
     class PipelineOptions(TypedDict, total=False):
+        alphabet: Alphabet
         background: typing.Optional[Background]
         bias_filter: bool
         null2: bool
@@ -107,6 +108,14 @@ if typing.TYPE_CHECKING:
         incdomT: typing.Optional[float]
         bit_cutoffs: typing.Optional[BIT_CUTOFFS]
 
+    class LongTargetsPipelineOptions(PipelineOptions, total=False):
+        strand: typing.Optional[STRAND]
+        B1: int
+        B2: int
+        B3: int
+        block_length: int
+        window_length: typing.Optional[int]
+        window_beta: typing.Optional[float]
     
 
     
@@ -218,13 +227,12 @@ class _BaseWorker(typing.Generic[_Q, _T, _R], threading.Thread):
         callback: typing.Optional[typing.Callable[[_Q, int], None]],
         options: "PipelineOptions",
         pipeline_class: typing.Type[Pipeline],
-        alphabet: Alphabet,
         builder: typing.Optional[Builder] = None,
     ) -> None:
         super().__init__()
         self.options = options
         self.targets: _T = targets
-        self.pipeline = pipeline_class(alphabet=alphabet, **options)
+        self.pipeline = pipeline_class(**options)
         self.query_queue: "queue.Queue[typing.Optional[_Chore[_Q, _R]]]" = query_queue
         self.query_count = query_count
         self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = (
@@ -333,9 +341,8 @@ class _JACKHMMERWorker(
         query_count: multiprocessing.Value,  # type: ignore
         kill_switch: threading.Event,
         callback: typing.Optional[typing.Callable[[_JACKHMMERQueryType, int], None]],
-        options: PipelineOptions,
+        options: "PipelineOptions",
         pipeline_class: typing.Type[Pipeline],
-        alphabet: Alphabet,
         builder: typing.Optional[Builder] = None,
         max_iterations: typing.Optional[int] = 5,
         select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
@@ -349,7 +356,6 @@ class _JACKHMMERWorker(
             callback=callback,
             options=options,
             pipeline_class=pipeline_class,
-            alphabet=alphabet,
             builder=builder,
         )
         self.select_hits = select_hits
@@ -469,7 +475,6 @@ class _BaseDispatcher(typing.Generic[_Q, _T, _R], abc.ABC):
         cpus: int = 0,
         callback: typing.Optional[typing.Callable[[_Q, int], None]] = None,
         pipeline_class: typing.Type[Pipeline] = Pipeline,
-        alphabet: Alphabet = Alphabet.amino(),
         builder: typing.Optional[Builder] = None,
         timeout: int = 1,
         **options,  # type: Unpack[PipelineOptions]
@@ -479,7 +484,6 @@ class _BaseDispatcher(typing.Generic[_Q, _T, _R], abc.ABC):
         self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback
         self.options = options
         self.pipeline_class = pipeline_class
-        self.alphabet = alphabet
         self.builder = builder
         self.timeout = timeout
 
@@ -602,7 +606,7 @@ class _SEARCHDispatcher(
                 self.targets.name,
                 format=self.targets.format,
                 digital=True,
-                alphabet=self.alphabet,
+                alphabet=self.options["alphabet"],
             )
         else:
             targets = self.targets  # type: ignore
@@ -614,7 +618,6 @@ class _SEARCHDispatcher(
             self.callback,
             self.options,
             self.pipeline_class,
-            self.alphabet,
         )
 
 
@@ -637,7 +640,7 @@ class _PHMMERDispatcher(
                 self.targets.name,
                 format=self.targets.format,
                 digital=True,
-                alphabet=self.alphabet,
+                alphabet=self.options["alphabet"],
             )
         else:
             targets = self.targets  # type: ignore
@@ -649,7 +652,6 @@ class _PHMMERDispatcher(
             self.callback,
             self.options,
             self.pipeline_class,
-            self.alphabet,
             copy.copy(self.builder),
         )
 
@@ -673,7 +675,6 @@ class _JACKHMMERDispatcher(
             typing.Callable[[_JACKHMMERQueryType, int], None]
         ] = None,
         pipeline_class: typing.Type[Pipeline] = Pipeline,
-        alphabet: Alphabet = Alphabet.amino(),
         builder: typing.Optional[Builder] = None,
         timeout: int = 1,
         max_iterations: typing.Optional[int] = 5,
@@ -687,7 +688,6 @@ class _JACKHMMERDispatcher(
             cpus=cpus,
             callback=callback,
             pipeline_class=pipeline_class,
-            alphabet=alphabet,
             builder=builder,
             timeout=timeout,
             **options,
@@ -710,7 +710,6 @@ class _JACKHMMERDispatcher(
             self.callback,
             self.options,
             self.pipeline_class,
-            self.alphabet,
             copy.copy(self.builder),
             self.max_iterations,
             self.select_hits,
@@ -733,11 +732,10 @@ class _NHMMERDispatcher(
         callback: typing.Optional[
             typing.Callable[[_NHMMERQueryType, int], None]
         ] = None,
-        pipeline_class: typing.Type[Pipeline] = LongTargetsPipeline,
-        alphabet: Alphabet = Alphabet.dna(),
+        pipeline_class: typing.Type[LongTargetsPipeline] = LongTargetsPipeline,
         builder: typing.Optional[Builder] = None,
         timeout: int = 1,
-        **options,  # type: Unpack[PipelineOptions]
+        **options,  # type: Unpack[LongTargetsPipelineOptions]
     ) -> None:
         super().__init__(
             queries=queries,
@@ -745,10 +743,9 @@ class _NHMMERDispatcher(
             cpus=cpus,
             callback=callback,
             pipeline_class=pipeline_class,
-            alphabet=alphabet,
             builder=builder,
             timeout=timeout,
-            **options,
+            **options,  # type: ignore
         )
 
     def _new_thread(
@@ -763,7 +760,7 @@ class _NHMMERDispatcher(
                 self.targets.name,
                 format=self.targets.format,
                 digital=True,
-                alphabet=self.alphabet,
+                alphabet=self.options["alphabet"],
             )
         else:
             targets = self.targets  # type: ignore
@@ -775,7 +772,6 @@ class _NHMMERDispatcher(
             self.callback,
             self.options,
             self.pipeline_class,
-            self.alphabet,
             copy.copy(self.builder),
         )
 
@@ -806,7 +802,6 @@ class _SCANDispatcher(
             self.callback,
             self.options,
             self.pipeline_class,
-            self.alphabet,
         )
 
 
@@ -882,7 +877,8 @@ def hmmsearch(
         Queries may now be an iterable of different types, or a single object.
 
     """
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
+    cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
+    alphabet = options.get("alphabet")
 
     if not isinstance(queries, collections.abc.Iterable):
         queries = (queries,)
@@ -894,26 +890,27 @@ def hmmsearch(
         if not sequence_file.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
         assert sequence_file.alphabet is not None
-        alphabet = sequence_file.alphabet
+        alphabet = alphabet or sequence_file.alphabet
         targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequence_file
     elif isinstance(sequences, DigitalSequenceBlock):
-        alphabet = sequences.alphabet
+        alphabet = alphabet or sequences.alphabet
         targets = sequences
     else:
         queries = peekable(queries)
         try:
-            alphabet = queries.peek().alphabet
+            alphabet = alphabet or queries.peek().alphabet or Alphabet.amino()
             targets = DigitalSequenceBlock(alphabet, sequences)
         except StopIteration:
-            alphabet = Alphabet.amino()
+            alphabet = alphabet or Alphabet.amino()
             targets = DigitalSequenceBlock(alphabet)
 
+    if "alphabet" not in options:
+        options["alphabet"] = alphabet
     dispatcher = _SEARCHDispatcher(
         queries=queries,
         targets=targets,
-        cpus=_cpus,
+        cpus=cpus,
         callback=callback,  # type: ignore
-        alphabet=alphabet,
         builder=None,
         pipeline_class=Pipeline,
         **options,
@@ -979,9 +976,8 @@ def phmmer(
         Queries may now be an iterable of different types, or a single object.
 
     """
-    _alphabet = Alphabet.amino()
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
-    _builder = Builder(_alphabet) if builder is None else builder
+    cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
+    alphabet = options.get("alphabet")
 
     if not isinstance(queries, collections.abc.Iterable):
         queries = (queries,)
@@ -993,23 +989,30 @@ def phmmer(
         if not sequence_file.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
         assert sequence_file.alphabet is not None
-        alphabet = sequence_file.alphabet
+        alphabet = alphabet or sequence_file.alphabet
         targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequence_file
     elif isinstance(sequences, DigitalSequenceBlock):
-        alphabet = sequences.alphabet
+        alphabet = alphabet or sequences.alphabet
         targets = sequences
     else:
-        alphabet = _alphabet
-        targets = DigitalSequenceBlock(_alphabet, sequences)
+        alphabet = alphabet or Alphabet.amino()
+        targets = DigitalSequenceBlock(alphabet, sequences)
 
+    if builder is None:
+        builder = Builder(
+            alphabet, 
+            seed=options.get("seed", 42)
+        )
+
+    if "alphabet" not in options:
+        options["alphabet"] = alphabet
     dispatcher = _PHMMERDispatcher(
         queries=queries,
         targets=targets,
-        cpus=_cpus,
+        cpus=cpus,
         callback=callback,  # type: ignore
         pipeline_class=Pipeline,
-        alphabet=alphabet,
-        builder=_builder,
+        builder=builder,
         **options,
     )
     return dispatcher.run()  # type: ignore
@@ -1146,12 +1149,10 @@ def jackhmmer(
     .. versionadded:: 0.8.0
 
     """
-    _alphabet = Alphabet.amino()
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
-    _builder = Builder(_alphabet, architecture="hand") if builder is None else builder
-
+    cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
     options.setdefault("incE", 0.001)
     options.setdefault("incdomE", 0.001)
+    alphabet = options.get("alphabet")
 
     if not isinstance(queries, collections.abc.Iterable):
         queries = (queries,)
@@ -1163,27 +1164,35 @@ def jackhmmer(
         if not sequence_file.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
         assert sequence_file.alphabet is not None
-        alphabet = sequence_file.alphabet
+        alphabet = alphabet or sequence_file.alphabet
         targets = typing.cast(DigitalSequenceBlock, sequence_file.read_block())
     elif isinstance(sequences, DigitalSequenceBlock):
-        alphabet = sequences.alphabet
+        alphabet = alphabet or sequences.alphabet
         targets = sequences
     else:
-        alphabet = _alphabet
-        targets = DigitalSequenceBlock(_alphabet, sequences)
+        alphabet = alphabet or Alphabet.amino()
+        targets = DigitalSequenceBlock(alphabet, sequences)
+
+    if builder is None:
+        builder = Builder(
+            alphabet,
+            seed=options.get("seed", 42),
+            architecture="hand",
+        )
 
-    dispatcher = _JACKHMMERDispatcher(
+    if "alphabet" not in options:
+        options["alphabet"] = alphabet
+    dispatcher = _JACKHMMERDispatcher(  # type: ignore
         queries=queries,
         targets=targets,
-        cpus=_cpus,
+        cpus=cpus,
         callback=callback,
         pipeline_class=Pipeline,
-        alphabet=alphabet,
-        builder=_builder,
+        builder=builder,
         max_iterations=max_iterations,
         select_hits=select_hits,
         checkpoints=checkpoints,
-        **options,  # type: ignore
+        **options,
     )
     return dispatcher.run()
 
@@ -1198,7 +1207,7 @@ def nhmmer(
     cpus: int = 0,
     callback: typing.Optional[typing.Callable[[_N, int], None]] = None,
     builder: typing.Optional[Builder] = None,
-    **options,  # type: Unpack[PipelineOptions]
+    **options,  # type: Unpack[LongTargetsPipelineOptions]
 ) -> typing.Iterator["TopHits[_N]"]:
     """Search nucleotide sequences against a sequence database.
 
@@ -1250,18 +1259,8 @@ def nhmmer(
         Queries may now be an iterable of different types, or a single object.
 
     """
-    _alphabet = Alphabet.dna()
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
-
-    if builder is None:
-        _builder = Builder(
-            _alphabet,
-            seed=typing.cast(int, options.get("seed", 42)),
-            window_length=typing.cast(typing.Optional[int], options.get("window_length")),
-            window_beta=typing.cast(typing.Optional[float], options.get("window_beta")),
-        )
-    else:
-        _builder = builder
+    cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
+    alphabet = options.get("alphabet")
 
     if not isinstance(queries, collections.abc.Iterable):
         queries = (queries,)
@@ -1273,23 +1272,32 @@ def nhmmer(
         if not sequence_file.digital:
             raise ValueError("expected digital mode `SequenceFile` for targets")
         assert sequence_file.alphabet is not None
-        alphabet = sequence_file.alphabet
+        alphabet = alphabet or sequence_file.alphabet
         targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequence_file
     elif isinstance(sequences, DigitalSequenceBlock):
-        alphabet = sequences.alphabet
+        alphabet = alphabet or sequences.alphabet
         targets = sequences
     else:
-        alphabet = _alphabet
-        targets = DigitalSequenceBlock(_alphabet, sequences)
+        alphabet = alphabet or Alphabet.dna()
+        targets = DigitalSequenceBlock(alphabet, sequences)
+
+    if builder is None:
+        builder = Builder(
+            alphabet,
+            seed=options.get("seed", 42),
+            window_length=options.get("window_length"),
+            window_beta=options.get("window_beta"),
+        )
 
+    if "alphabet" not in options:
+        options["alphabet"] = alphabet
     dispatcher = _NHMMERDispatcher(
         queries=queries,
         targets=targets,
-        cpus=_cpus,
+        cpus=cpus,
         callback=callback,  # type: ignore
         pipeline_class=LongTargetsPipeline,
-        alphabet=alphabet,
-        builder=_builder,
+        builder=builder,
         **options,
     )
     return dispatcher.run()  # type: ignore
@@ -1426,7 +1434,6 @@ def hmmscan(
     *,
     cpus: int = 0,
     callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]] = None,
-    background: typing.Optional[Background] = None,
     **options,  # type: Unpack[PipelineOptions]
 ) -> typing.Iterator["TopHits[DigitalSequence]"]:
     """Scan query sequences against a profile database.
@@ -1499,26 +1506,29 @@ def hmmscan(
     .. versionadded:: 0.7.0
 
     """
-    _alphabet = Alphabet.amino()
-    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
-    _background = Background(_alphabet) if background is None else background
-    options.setdefault("background", _background)
+    cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1
+    alphabet = options.get("alphabet")
+    background = options.get("background")
 
     if not isinstance(queries, collections.abc.Iterable):
         queries = (queries,)
     if isinstance(profiles, HMMPressedFile):
-        alphabet = _alphabet  # FIXME: try to detect from content instead?
+        alphabet = alphabet or Alphabet.amino() # FIXME: try to detect alphabet?
         targets = profiles
     elif isinstance(profiles, OptimizedProfileBlock):
-        alphabet = profiles.alphabet
+        alphabet = alphabet or profiles.alphabet
         targets = profiles  # type: ignore
     else:
-        alphabet = _alphabet
-        block = OptimizedProfileBlock(_alphabet)
+        block = None
         for item in profiles:
+            alphabet = alphabet or item.alphabet
+            if block is None:
+                block = OptimizedProfileBlock(item.alphabet)
             if isinstance(item, HMM):
+                if background is None:
+                    background = Background(item.alphabet)
                 profile = Profile(item.M, item.alphabet)
-                profile.configure(item, _background)
+                profile.configure(item, background)
                 item = profile
             if isinstance(item, Profile):
                 item = item.to_optimized()
@@ -1529,17 +1539,24 @@ def hmmscan(
                 raise TypeError(
                     "Expected HMM, Profile or OptimizedProfile, found {}".format(ty)
                 )
+        if alphabet is None:
+            alphabet = Alphabet.amino()
+        if block is None:
+            block = OptimizedProfileBlock(alphabet)
         targets = block  # type: ignore
 
+    if "alphabet" not in options:
+        options["alphabet"] = alphabet
+    if "background" not in options and background is not None:
+        options["background"] = background
     dispatcher = _SCANDispatcher(
         queries=queries,
         targets=targets,
-        cpus=_cpus,
+        cpus=cpus,
         callback=callback,
         pipeline_class=Pipeline,
-        alphabet=alphabet,
         builder=None,
-        **options,  # type: ignore
+        **options,
     )
     return dispatcher.run()
 
-- 
GitLab