hmmer.py 35.2 KB
Newer Older
1
# coding: utf-8
2
"""Reimplementation of HMMER binaries with the pyHMMER API.
3
"""
4

5
import abc
6
import contextlib
7
import collections
8
import ctypes
9
import itertools
10
import io
11
import queue
12
import time
13
import threading
14
import typing
15
import os
16
17
import multiprocessing

18
19
import psutil

20
from .easel import Alphabet, DigitalSequence, DigitalMSA, MSA, MSAFile, TextSequence, SequenceFile, SSIWriter
21
from .plan7 import Builder, Background, Pipeline, PipelineSearchTargets, LongTargetsPipeline, TopHits, HMM, HMMFile, Profile, TraceAligner, OptimizedProfile
22
from .utils import peekable
23

24
# Type variables making the pipeline threads and search runners generic over
# the kind of query they handle:
#
# - `_Q`: any query a pipeline accepts (models, sequences or alignments);
# - `_M`: model queries, searched with `Pipeline.search_hmm`;
# - `_S`: sequence or alignment queries, converted to models with a `Builder`.
_Q = typing.TypeVar(
    "_Q",
    HMM,
    Profile,
    OptimizedProfile,
    DigitalSequence,
    DigitalMSA,
)
_M = typing.TypeVar("_M", HMM, Profile, OptimizedProfile)
_S = typing.TypeVar("_S", DigitalSequence, DigitalMSA)
30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# --- Result class -----------------------------------------------------------

class _ResultBuffer:
    """A single-use slot holding the outcome of processing one query.

    A worker thread fills the slot either with the `TopHits` obtained for
    the query (`set`) or with the exception raised while processing it
    (`fail`). A consumer calling `get` blocks until one of them happened.
    """

    event: threading.Event
    hits: typing.Optional[TopHits]
    exception: typing.Optional[BaseException]

    __slots__ = ("event", "hits", "exception")

    def __init__(self) -> None:
        self.hits = None
        self.exception = None
        self.event = threading.Event()

    def available(self) -> bool:
        """Return whether hits or an error have been stored already."""
        return self.event.is_set()

    def get(self) -> TopHits:
        """Wait for the outcome, then return the hits or raise the error."""
        self.event.wait()
        error = self.exception
        if error is None:
            return typing.cast(TopHits, self.hits)
        raise error

    def set(self, hits: TopHits) -> None:
        """Store the hits of the query and wake up waiting consumers."""
        self.hits = hits
        self.event.set()

    def fail(self, exception: BaseException) -> None:
        """Store the error of the query and wake up waiting consumers."""
        self.exception = exception
        self.event.set()


64
# --- Pipeline threads -------------------------------------------------------
65
66

class _PipelineThread(typing.Generic[_Q], threading.Thread):
    """A generic worker thread to parallelize a pipelined search.

    Attributes:
        sequences (`pyhmmer.plan7.PipelineSearchTargets`): The target
            sequences to search for hits.
        query_queue (`queue.Queue`): The queue used to pass queries between
            threads. It contains both the query, its index so that the
            results can be returned in the same order, and a `_ResultBuffer`
            where to store the result when the query has been processed.
            A `None` item is a poison-pill telling the thread to stop.
        query_count (`multiprocessing.Value`): An atomic counter storing
            the total number of queries that have currently been loaded.
            Passed to the ``callback`` so that an UI can show the total
            for a progress bar.
        kill_switch (`threading.Event`): An event flag shared between
            all worker threads, used to notify emergency exit.
        callback (`callable`): The callback called after each query has
            been processed, with two arguments: the query object that was
            processed, and the total number of queries read until now. A
            no-op callback is used when `None` is given.
        options (`dict`): A dictionary of options passed to the pipeline
            wrapped by the worker thread.
        pipeline (`~plan7.Pipeline`): The pipeline used to search for hits,
            created from the ``pipeline_class`` (`~plan7.LongTargetsPipeline`
            for `nhmmer`, `~plan7.Pipeline` everywhere else), the
            ``alphabet`` and the ``options`` given to the constructor.
        error (`BaseException`, optional): The exception that stopped the
            thread, or `None` if no query failed.

    """

    @staticmethod
    def _none_callback(hmm: _Q, total: int) -> None:
        pass

    def __init__(
        self,
        sequences: PipelineSearchTargets,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
        query_count: multiprocessing.Value,  # type: ignore
        kill_switch: threading.Event,
        callback: typing.Optional[typing.Callable[[_Q, int], None]],
        options: typing.Dict[str, typing.Any],
        pipeline_class: typing.Type[Pipeline],
        alphabet: Alphabet,
    ) -> None:
        super().__init__()
        self.options = options
        self.sequences = sequences
        self.pipeline = pipeline_class(alphabet=alphabet, **options)
        self.query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]" = query_queue
        self.query_count = query_count
        self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback or self._none_callback
        self.kill_switch = kill_switch
        self.error: typing.Optional[BaseException] = None

    def run(self) -> None:
        while not self.kill_switch.is_set():
            # attempt to get the next argument, with a timeout
            # so that the thread can periodically check if it has
            # been killed, even when the query queue is empty
            try:
                args = self.query_queue.get(timeout=1)
            except queue.Empty:
                continue
            # every item taken from the queue is matched by exactly one
            # `task_done` call, whether processing succeeded or not
            try:
                # a poison-pill (`None`) tells the thread to stop running
                if args is None:
                    return
                index, query, result_buffer = args
                # process the arguments, making sure to capture any exception
                # raised while processing the query so that the consumer
                # waiting on the result buffer is woken up with the error
                try:
                    hits = self.process(index, query)
                except BaseException as exc:
                    self.error = exc
                    self.kill()
                    result_buffer.fail(exc)
                else:
                    result_buffer.set(hits)
            finally:
                self.query_queue.task_done()

    def kill(self) -> None:
        """Set the kill switch shared by all the worker threads."""
        self.kill_switch.set()

    def process(self, index: int, query: _Q) -> TopHits:
        """Search a single query, report progress, and reset the pipeline."""
        hits = self.search(query)
        self.callback(query, self.query_count.value)  # type: ignore
        self.pipeline.clear()
        return hits

    @abc.abstractmethod
    def search(self, query: _Q) -> TopHits:
        """Search the target sequences with ``query`` and return the hits.

        `threading.Thread` is not an ABC, so `abc.abstractmethod` is not
        enforced here: raise instead of returning a bogus hits object.
        """
        raise NotImplementedError


161
162
class _ModelPipelineThread(typing.Generic[_M], _PipelineThread[_M]):
    """A worker thread searching model queries against the targets.

    Queries (`HMM`, `Profile` or `OptimizedProfile`) are handed directly
    to the `search_hmm` method of the thread's pipeline.
    """

    def search(self, query: _M) -> TopHits:
        pipeline = self.pipeline
        targets = self.sequences
        return pipeline.search_hmm(query, targets)


class _SequencePipelineThread(_PipelineThread[DigitalSequence]):
    """A worker thread searching single sequence queries against the targets.

    Each query is passed to the `search_seq` method of the thread's
    pipeline, together with the `Builder` given at construction, which
    configures how the query is turned into a model.
    """

    def __init__(
        self,
        sequences: PipelineSearchTargets,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
        query_count: multiprocessing.Value,  # type: ignore
        kill_switch: threading.Event,
        callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]],
        options: typing.Dict[str, typing.Any],
        pipeline_class: typing.Type[Pipeline],
        alphabet: Alphabet,
        builder: Builder,
    ) -> None:
        super().__init__(
            sequences=sequences,
            query_queue=query_queue,
            query_count=query_count,
            kill_switch=kill_switch,
            callback=callback,
            options=options,
            pipeline_class=pipeline_class,
            alphabet=alphabet,
        )
        self.builder = builder

    def search(self, query: DigitalSequence) -> TopHits:
        pipeline, targets = self.pipeline, self.sequences
        return pipeline.search_seq(query, targets, self.builder)

194

195
196
197
class _MSAPipelineThread(_PipelineThread[DigitalMSA]):
    """A worker thread searching alignment queries against the targets.

    Each query is passed to the `search_msa` method of the thread's
    pipeline, together with the `Builder` given at construction, which
    configures how the alignment is turned into a model.
    """

    def __init__(
        self,
        sequences: PipelineSearchTargets,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
        query_count: multiprocessing.Value,  # type: ignore
        kill_switch: threading.Event,
        callback: typing.Optional[typing.Callable[[DigitalMSA, int], None]],
        options: typing.Dict[str, typing.Any],
        pipeline_class: typing.Type[Pipeline],
        alphabet: Alphabet,
        builder: Builder,
    ) -> None:
        super().__init__(
            sequences=sequences,
            query_queue=query_queue,
            query_count=query_count,
            kill_switch=kill_switch,
            callback=callback,
            options=options,
            pipeline_class=pipeline_class,
            alphabet=alphabet,
        )
        self.builder = builder

    def search(self, query: DigitalMSA) -> TopHits:
        pipeline, targets = self.pipeline, self.sequences
        return pipeline.search_msa(query, targets, self.builder)
222
223


224
225
226
227
228
229
230
# --- Search runners ---------------------------------------------------------

class _Search(typing.Generic[_Q], abc.ABC):
    """A runner dispatching queries to pipeline threads.

    Subclasses implement `_new_thread` to create a worker thread suited to
    their query type. `run` then processes every query, either in the
    calling thread or with a pool of worker threads, and yields one
    `~pyhmmer.plan7.TopHits` per query, in the order of the queries.

    Arguments:
        queries (iterable): The queries to search for.
        sequences (iterable of `~pyhmmer.easel.DigitalSequence`): The
            target sequences, wrapped in a `PipelineSearchTargets` unless
            they already are one.
        cpus (`int`): The number of worker threads. Values of ``1`` or
            less process all queries in the calling thread.
        callback (callable, optional): Called after each query with the
            query and the number of queries read so far.
        pipeline_class (`type`): The pipeline class used by the workers.
        alphabet (`~pyhmmer.easel.Alphabet`): The alphabet used to create
            the pipelines.
        **options: Additional options passed to the pipelines.

    """

    def __init__(
        self,
        queries: typing.Iterable[_Q],
        sequences: typing.Iterable[DigitalSequence],
        cpus: int = 0,
        callback: typing.Optional[typing.Callable[[_Q, int], None]] = None,
        pipeline_class: typing.Type[Pipeline] = Pipeline,
        alphabet: Alphabet = Alphabet.amino(),
        **options # type: typing.Dict[str, object]
    ) -> None:
        self.queries: typing.Iterable[_Q] = queries
        self.cpus = cpus
        self.callback: typing.Optional[typing.Callable[[_Q, int], None]] = callback
        self.options = options
        self.pipeline_class = pipeline_class
        self.alphabet = alphabet
        # wrap the targets once, so that all threads share the same
        # `PipelineSearchTargets` object
        if isinstance(sequences, PipelineSearchTargets):
            self.sequences = sequences
        else:
            self.sequences = PipelineSearchTargets(sequences)

    @abc.abstractmethod
    def _new_thread(
        self,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _Q, _ResultBuffer]]]",
        query_count: "multiprocessing.Value[int]",  # type: ignore
        kill_switch: threading.Event,
    ) -> _PipelineThread[_Q]:
        """Create a (non-started) worker thread consuming ``query_queue``."""
        return NotImplemented

    @staticmethod
    def _put(
        query_queue: "queue.Queue[typing.Any]",
        item: typing.Any,
        kill_switch: threading.Event,
    ) -> bool:
        """Put ``item`` in ``query_queue`` unless the threads were killed.

        The query queue is bounded, so a plain blocking `put` would wait
        forever once the worker threads stopped consuming it after one of
        them failed. Returns `True` if the item was enqueued, or `False`
        if ``kill_switch`` was set before there was room in the queue.
        """
        while not kill_switch.is_set():
            try:
                query_queue.put(item, timeout=1)
            except queue.Full:
                continue
            else:
                return True
        return False

    def _single_threaded(self) -> typing.Iterator[TopHits]:
        # create the queues to pass the HMM objects around, as well as atomic
        # values that we use to synchronize the threads
        query_queue = queue.Queue()  # type: ignore
        query_count = multiprocessing.Value(ctypes.c_ulong)
        kill_switch = threading.Event()

        # create the thread (to recycle code); it is never started, its
        # `process` method is called directly in the current thread
        thread = self._new_thread(query_queue, query_count, kill_switch)

        # process each query iteratively and yield the result
        # immediately so that the user can iterate over the
        # TopHits one at a time
        for index, query in enumerate(self.queries):
            query_count.value += 1
            yield thread.process(index, query)

    def _multi_threaded(self) -> typing.Iterator[TopHits]:
        # create the queues to pass the HMM objects around, as well as atomic
        # values that we use to synchronize the threads
        results: typing.Deque[_ResultBuffer] = collections.deque()
        query_count = multiprocessing.Value(ctypes.c_ulong)
        kill_switch = threading.Event()

        # the query queue is bounded so that we only feed more queries
        # if the worker threads are waiting for some
        query_queue = queue.Queue(maxsize=self.cpus)  # type: ignore

        # additional type annotations
        query: typing.Optional[_Q]
        index: int

        # create and launch one pipeline thread per CPU
        threads = []
        for _ in range(self.cpus):
            thread = self._new_thread(query_queue, query_count, kill_switch)
            thread.start()
            threads.append(thread)

        # catch exceptions to kill threads in the background before exiting
        try:
            # enumerate queries, so that we know the index of each query
            # and we can yield the results in the same order
            queries = enumerate(self.queries)
            # initially feed one query per thread so that they can start
            # working before we enter the main loop (the queue is empty
            # and holds `cpus` items, so these calls never block)
            for (index, query) in itertools.islice(queries, self.cpus):
                query_count.value += 1
                query_result = _ResultBuffer()
                query_queue.put((index, query, query_result))
                results.append(query_result)
            # alternate between feeding queries to the threads and
            # yielding back results, if available
            while results:
                # get the next query, or break the loop if there is no query
                # left to process in the input iterator.
                index, query = next(queries, (-1, None))
                if query is None:
                    break
                query_count.value += 1
                query_result = _ResultBuffer()
                # worker threads expect `(index, query, result_buffer)`
                # triples; stop feeding them if they were killed
                if not self._put(query_queue, (index, query, query_result), kill_switch):
                    break
                results.append(query_result)
                # yield the top hits for the next query, if available
                if results[0].available():
                    yield results[0].get()
                    results.popleft()
            # now that we exhausted all queries, poison pill the
            # threads so they stop on their own (killed threads have
            # already stopped and do not need one)
            for _ in threads:
                if not self._put(query_queue, None, kill_switch):
                    break
            # yield remaining results; queries are dequeued in order, so if
            # a worker failed, the buffer of the failed query is reached
            # (and re-raises the error) before any unprocessed query
            while results:
                yield results[0].get()
                results.popleft()
        except BaseException:
            # make sure threads are killed to avoid being stuck,
            # e.g. after a KeyboardInterrupt
            kill_switch.set()
            raise

    def run(self) -> typing.Iterator[TopHits]:
        """Process all queries, yielding one `TopHits` per query in order.

        With ``cpus`` of ``1`` or less, the queries are processed in the
        calling thread (a pool of zero threads would yield nothing);
        otherwise ``cpus`` worker threads are spawned.
        """
        if self.cpus <= 1:
            return self._single_threaded()
        else:
            return self._multi_threaded()


345
class _ModelSearch(typing.Generic[_M], _Search[_M]):
    """A search runner for model queries (`HMM`, `Profile`, `OptimizedProfile`)."""

    def _new_thread(
        self,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, _M, _ResultBuffer]]]",
        query_count: "multiprocessing.Value[int]",  # type: ignore
        kill_switch: threading.Event,
    ) -> _ModelPipelineThread[_M]:
        # every worker receives the runner's targets, callback and options
        thread_args = (
            self.sequences,
            query_queue,
            query_count,
            kill_switch,
            self.callback,
            self.options,
            self.pipeline_class,
            self.alphabet,
        )
        return _ModelPipelineThread(*thread_args)


class _SequenceSearch(_Search[DigitalSequence]):
    """A search runner for single sequence queries.

    Every worker thread created by this runner gets its own copy of the
    ``builder``, obtained with `Builder.copy`.
    """

    def __init__(
        self,
        builder: Builder,
        queries: typing.Iterable[DigitalSequence],
        sequences: typing.Iterable[DigitalSequence],
        cpus: int = 0,
        callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]] = None,
        pipeline_class: typing.Type[Pipeline] = Pipeline,
        alphabet: Alphabet = Alphabet.amino(),
        **options, # type: typing.Dict[str, object]
    ) -> None:
        super().__init__(
            queries,
            sequences,
            cpus=cpus,
            callback=callback,
            pipeline_class=pipeline_class,
            alphabet=alphabet,
            **options,
        )
        self.builder = builder

    def _new_thread(
        self,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalSequence, _ResultBuffer]]]",
        query_count: "multiprocessing.Value[int]",  # type: ignore
        kill_switch: threading.Event,
    ) -> _SequencePipelineThread:
        thread_builder = self.builder.copy()
        return _SequencePipelineThread(
            self.sequences, query_queue, query_count, kill_switch,
            self.callback, self.options, self.pipeline_class, self.alphabet,
            thread_builder,
        )

399

400
401
402
403
404
405
class _MSASearch(_Search[DigitalMSA]):
    """A search runner for multiple sequence alignment queries.

    Every worker thread created by this runner gets its own copy of the
    ``builder``, obtained with `Builder.copy`.
    """

    def __init__(
        self,
        builder: Builder,
        queries: typing.Iterable[DigitalMSA],
        sequences: typing.Iterable[DigitalSequence],
        cpus: int = 0,
        callback: typing.Optional[typing.Callable[[DigitalMSA, int], None]] = None,
        pipeline_class: typing.Type[Pipeline] = Pipeline,
        alphabet: Alphabet = Alphabet.amino(),
        **options, # type: typing.Dict[str, object]
    ) -> None:
        super().__init__(
            queries,
            sequences,
            cpus=cpus,
            callback=callback,
            pipeline_class=pipeline_class,
            alphabet=alphabet,
            **options,
        )
        self.builder = builder

    def _new_thread(
        self,
        query_queue: "queue.Queue[typing.Optional[typing.Tuple[int, DigitalMSA, _ResultBuffer]]]",
        query_count: "multiprocessing.Value[int]",  # type: ignore
        kill_switch: threading.Event,
    ) -> _MSAPipelineThread:
        thread_builder = self.builder.copy()
        return _MSAPipelineThread(
            self.sequences, query_queue, query_count, kill_switch,
            self.callback, self.options, self.pipeline_class, self.alphabet,
            thread_builder,
        )


435
# --- hmmsearch --------------------------------------------------------------
436

437
def hmmsearch(
    queries: typing.Iterable[_M],
    sequences: typing.Iterable[DigitalSequence],
    cpus: int = 0,
    callback: typing.Optional[typing.Callable[[_M, int], None]] = None,
    **options,  # type: typing.Dict[str, object]
) -> typing.Iterator[TopHits]:
    """Search HMM profiles against a sequence database.

    Arguments:
        queries (iterable of `HMM`, `Profile` or `OptimizedProfile`): The
            query HMMs or profiles to search for in the database.
        sequences (collection of `~pyhmmer.easel.DigitalSequence`): A
            database of sequences to query.
        cpus (`int`): The number of threads to run in parallel. Pass ``1``
            to run everything in the main thread, ``0`` to automatically
            select a suitable number (using `psutil.cpu_count`), or any
            positive number otherwise.
        callback (callable): A callback that is called every time a query
            is processed with two arguments: the query, and the number of
            queries read so far. This can be used to display progress in UI.

    Yields:
        `~pyhmmer.plan7.TopHits`: An object reporting *top hits* for each
        query, in the same order the queries were passed in the input.

    Raises:
        `~pyhmmer.errors.AlphabetMismatch`: When any of the query HMMs
        and the sequences do not share the same alphabet.

    Note:
        Any additional arguments passed to the `hmmsearch` function will be
        passed transparently to the `~pyhmmer.plan7.Pipeline` to be created.
        Unless an ``alphabet`` keyword argument is given, the pipelines are
        created with the alphabet of the first query, so that nucleotide
        models can be searched as well as protein models.

    .. versionadded:: 0.1.0

    .. versionchanged:: 0.4.9
       Allow using `Profile` and `OptimizedProfile` queries.

    """
    # count the number of CPUs to use
    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1

    # peek at the first query (without consuming it) so that the pipelines
    # use the alphabet of the queries instead of always assuming proteins;
    # an explicit `alphabet` keyword argument takes precedence
    _queries: peekable[_M] = peekable(queries)
    if "alphabet" not in options:
        try:
            _item = _queries.peek()
        except StopIteration:
            _item = None
        if isinstance(_item, (HMM, Profile, OptimizedProfile)):
            options["alphabet"] = _item.alphabet

    runner: _ModelSearch[_M] = _ModelSearch(_queries, sequences, _cpus, callback, **options) # type: ignore
    return runner.run()
481
482


483
484
# --- phmmer -----------------------------------------------------------------

485
def phmmer(
    queries: typing.Iterable[_S],
    sequences: typing.Iterable[DigitalSequence],
    cpus: int = 0,
    callback: typing.Optional[typing.Callable[[_S, int], None]] = None,
    builder: typing.Optional[Builder] = None,
    **options, # type: typing.Dict[str, object]
) -> typing.Iterator[TopHits]:
    """Search protein sequences against a sequence database.

    Arguments:
        queries (iterable of `DigitalSequence` or `DigitalMSA`): The
            query sequences to search for in the sequence database. The
            type of the first query decides how all queries are processed.
        sequences (collection of `~pyhmmer.easel.DigitalSequence`): A
            database of sequences to query.
        cpus (`int`): The number of threads to run in parallel. Pass ``1`` to
            run everything in the main thread, ``0`` to automatically
            select a suitable number (using `psutil.cpu_count`), or any
            positive number otherwise.
        callback (callable): A callback that is called every time a query
            is processed with two arguments: the query, and the number of
            queries read so far. This can be used to display progress in UI.
        builder (`~pyhmmer.plan7.Builder`, optional): A builder to configure
            how the queries are converted to HMMs. Passing `None` will create
            a default instance for the protein alphabet.

    Yields:
        `~pyhmmer.plan7.TopHits`: A *top hits* instance for each query,
        in the same order the queries were passed in the input.

    Raises:
        `TypeError`: When the first query is neither a `DigitalSequence`
        nor a `DigitalMSA`.

    Note:
        Any additional keyword arguments passed to the `phmmer` function
        will be passed transparently to the `~pyhmmer.plan7.Pipeline` to
        be created in each worker thread.

    .. versionadded:: 0.2.0

    .. versionchanged:: 0.3.0
       Allow using `DigitalMSA` queries.

    """
    # resolve the number of worker threads to use
    if cpus > 0:
        _cpus = cpus
    else:
        _cpus = psutil.cpu_count(logical=False) or os.cpu_count() or 1

    # fall back to a protein builder when the caller did not provide one
    _builder = builder if builder is not None else Builder(Alphabet.amino())

    # look at the first query, without consuming it, to pick a runner
    try:
        _queries = peekable(queries)
        _item = _queries.peek()
    except StopIteration:
        _item = None

    runner: typing.Union[_SequenceSearch, _MSASearch]
    if isinstance(_item, DigitalMSA):
        runner = _MSASearch(
            _builder,
            _queries,  # type: ignore
            sequences,
            _cpus,
            callback,  # type: ignore
            pipeline_class=Pipeline,
            alphabet=Alphabet.amino(),
            **options
        )
    elif _item is None or isinstance(_item, DigitalSequence):
        runner = _SequenceSearch(
            _builder,
            _queries,  # type: ignore
            sequences,
            _cpus,
            callback,  # type: ignore
            pipeline_class=Pipeline,
            alphabet=Alphabet.amino(),
            **options
        )
    else:
        name = type(_item).__name__
        raise TypeError(f"Expected iterable of DigitalSequence or DigitalMSA, found {name}")

    return runner.run()


# --- nhmmer -----------------------------------------------------------------

560
def nhmmer(
    queries: typing.Iterable[_Q],
    sequences: typing.Iterable[DigitalSequence],
    cpus: int = 0,
    callback: typing.Optional[typing.Callable[[_Q, int], None]] = None,
    builder: typing.Optional[Builder] = None,
    **options, # type: typing.Dict[str, object]
) -> typing.Iterator[TopHits]:
    """Search nucleotide sequences against a sequence database.

    Arguments:
        queries (iterable of `DigitalSequence`, `DigitalMSA`, `HMM`): The
            query sequences or profiles to search for in the sequence
            database. The type of the first query decides how all queries
            are processed.
        sequences (collection of `~pyhmmer.easel.DigitalSequence`): A
            database of sequences to query.
        cpus (`int`): The number of threads to run in parallel. Pass ``1`` to
            run everything in the main thread, ``0`` to automatically
            select a suitable number (using `psutil.cpu_count`), or any
            positive number otherwise.
        callback (callable): A callback that is called every time a query
            is processed with two arguments: the query, and the number of
            queries read so far. This can be used to display progress in UI.
        builder (`~pyhmmer.plan7.Builder`, optional): A builder to configure
            how the queries are converted to HMMs. Passing `None` will create
            a default instance using the alphabet of the first query (or
            the DNA alphabet if there are no queries).

    Yields:
        `~pyhmmer.plan7.TopHits`: A *top hits* instance for each query,
        in the same order the queries were passed in the input.

    Raises:
        `TypeError`: When the first query is not a `DigitalSequence`,
        `DigitalMSA`, `HMM`, `Profile` or `OptimizedProfile`.

    Note:
        Any additional keyword arguments passed to the `nhmmer` function
        will be passed to the `~pyhmmer.plan7.LongTargetsPipeline` created
        in each worker thread. The ``strand`` argument can be used to
        restrict the search on the direct or reverse strand.

    Hint:
        This function is not just `phmmer` for nucleotide sequences; it
        actually uses a `~pyhmmer.plan7.LongTargetsPipeline` internally
        instead of processing each target sequence in its entirety when
        searching for hits. This avoids hitting the maximum target size
        that can be used (100,000 residues), which may be a problem for
        some larger genomes.

    .. versionadded:: 0.3.0

    .. versionchanged:: 0.4.9
       Allow using `Profile` and `OptimizedProfile` queries.

    """
    _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1

    try:
        _queries: peekable[_Q] = peekable(queries)
        _item: typing.Optional[_Q] = _queries.peek()
    except StopIteration:
        _item = None

    # use the alphabet of the first query for the pipelines AND for the
    # default builder, so that e.g. RNA queries are not built with a DNA
    # builder; fall back to DNA when there is no query to inspect
    if isinstance(_item, (DigitalSequence, DigitalMSA, HMM, Profile, OptimizedProfile)):
        _alphabet = _item.alphabet
    else:
        _alphabet = Alphabet.dna()
    _builder = Builder(_alphabet) if builder is None else builder

    runner: typing.Union[_SequenceSearch, _MSASearch, _ModelSearch[HMM]]
    if _item is None or isinstance(_item, DigitalSequence):
        runner = _SequenceSearch(
            _builder,
            typing.cast(peekable[DigitalSequence], _queries),
            sequences,
            _cpus,
            callback,  # type: ignore
            pipeline_class=LongTargetsPipeline,
            alphabet=_alphabet,
            **options,
        )
    elif isinstance(_item, DigitalMSA):
        runner = _MSASearch(
            _builder,
            typing.cast(peekable[DigitalMSA], _queries),
            sequences,
            _cpus,
            callback,  # type: ignore
            pipeline_class=LongTargetsPipeline,
            alphabet=_alphabet,
            **options,
        )
    elif isinstance(_item, (HMM, Profile, OptimizedProfile)):
        runner = _ModelSearch(
            typing.cast(peekable[HMM], _queries),
            sequences,
            _cpus,
            callback,  # type: ignore
            pipeline_class=LongTargetsPipeline,
            alphabet=_alphabet,
            **options,
        )
    else:
        name = type(_item).__name__
        raise TypeError(f"Expected iterable of DigitalSequence, DigitalMSA, HMM, Profile or OptimizedProfile, found {name}")
    return runner.run()
657

658

659
660
# --- hmmpress ---------------------------------------------------------------

661
def hmmpress(
    hmms: typing.Iterable[HMM], output: typing.Union[str, "os.PathLike[str]"],
) -> int:
    """Press several HMMs into a database.

    Calling this function will create 4 files at the given location:
    ``{output}.h3p`` (containing the optimized profiles),
    ``{output}.h3m`` (containing the binary HMMs),
    ``{output}.h3f`` (containing the MSV parameters), and
    ``{output}.h3i`` (the SSI index mapping the previous files).

    Arguments:
        hmms (iterable of `~pyhmmer.plan7.HMM`): The HMMs to be pressed
            together in the file.
        output (`str` or `os.PathLike`): The path to an output location
            where to write the different files.

    Returns:
        `int`: The number of HMMs written to the database.

    Raises:
        `ValueError`: When one of the HMMs has no name, since the index
        stores every HMM under its name.

    """
    DEFAULT_L = 400
    path = os.fspath(output)
    background = None
    count = 0

    with contextlib.ExitStack() as stack:
        # open the three binary outputs and the SSI index; the exit stack
        # closes all of them, even if pressing fails midway
        h3p = stack.enter_context(open("{}.h3p".format(path), "wb"))
        h3m = stack.enter_context(open("{}.h3m".format(path), "wb"))
        h3f = stack.enter_context(open("{}.h3f".format(path), "wb"))
        h3i = stack.enter_context(SSIWriter("{}.h3i".format(path)))
        file_handle = h3i.add_file(path, format=0)

        for hmm in hmms:
            # a single background model, created from the alphabet of the
            # first HMM, is used to configure every profile
            if background is None:
                background = Background(hmm.alphabet)
                background.L = DEFAULT_L

            profile = Profile(hmm.M, hmm.alphabet)
            profile.configure(hmm, background, DEFAULT_L)
            optimized = profile.optimized()

            # record where each part of this model starts in its file
            optimized.offsets.model = h3m.tell()
            optimized.offsets.profile = h3p.tell()
            optimized.offsets.filter = h3f.tell()

            # index the HMM by name, with its accession (if any) as an alias
            if hmm.name is None:
                raise ValueError("HMMs must have a name to be pressed.")
            h3i.add_key(hmm.name, file_handle, optimized.offsets.model, 0, 0)
            if hmm.accession is not None:
                h3i.add_alias(hmm.accession, hmm.name)

            # write the binary HMM, then the optimized profile data
            hmm.write(h3m, binary=True)
            optimized.write(h3f, h3p)
            count += 1

    return count
721
722


723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
# --- hmmalign ---------------------------------------------------------------

def hmmalign(
    hmm: HMM,
    sequences: typing.Collection[DigitalSequence],
    trim: bool = False,
    digitize: bool = False,
    all_consensus_cols: bool = True,
) -> MSA:
    """Align several sequences to a reference HMM, and return the MSA.

    Arguments:
        hmm (`~pyhmmer.plan7.HMM`): The reference HMM to use for the
            alignment.
        sequences (collection of `~pyhmmer.easel.DigitalSequence`): The
            sequences to align to the HMM. The same collection is used
            both to compute the traces and to build the alignment, so
            pass a reusable collection (e.g. a `list`) rather than a
            one-shot iterator.
        trim (`bool`): Trim off any residues that get assigned to
            flanking :math:`N` and :math:`C` states (in profile traces)
            or :math:`I_0` and :math:`I_m` (in core traces).
        digitize (`bool`): If set to `True`, returns a `DigitalMSA`
            instead of a `TextMSA`.
        all_consensus_cols (`bool`): Force a column to be created for
            every consensus column in the model, even if it means having
            all gap character in a column.

    Returns:
        `~pyhmmer.easel.MSA`: A multiple sequence alignment containing
        the aligned sequences, either a `TextMSA` or a `DigitalMSA`
        depending on the value of the ``digitize`` argument.

    See Also:
        The `~pyhmmer.plan7.TraceAligner` class, which lets you inspect the
        intermediate tracebacks obtained for each alignment before building
        a MSA.

    .. versionadded:: 0.4.7

    """
    aligner = TraceAligner()
    # first pass over `sequences`: one traceback per sequence
    traces = aligner.compute_traces(hmm, sequences)
    # second pass over `sequences`: assemble the tracebacks into an MSA
    return aligner.align_traces(
        hmm,
        sequences,
        traces,
        trim=trim,
        digitize=digitize,
        all_consensus_cols=all_consensus_cols,
    )


773
774
775
776
777
778
779
# add a very limited CLI so that this module can be invoked in a shell:
#     $ python -m pyhmmer.hmmer hmmsearch <hmmfile> <seqdb>
if __name__ == "__main__":

    import argparse
    import sys

    def _load_sequences(
        path: str,
        fmt: typing.Optional[str] = None,
        alphabet: typing.Optional[Alphabet] = None,
    ) -> typing.Optional[typing.List[DigitalSequence]]:
        """Read every sequence of ``path`` in digital mode.

        Arguments:
            path: The path to the sequence file.
            fmt: The sequence file format, or `None` to let the reader
                detect it.
            alphabet: The alphabet to digitize with, or `None` to let
                the reader choose one.

        Returns:
            The list of sequences, or `None` if reading raised an
            `EOFError` (the error is reported on stderr in that case).
        """
        try:
            with SequenceFile(path, fmt, digital=True, alphabet=alphabet) as seqfile:
                return list(seqfile)  # type: ignore
        except EOFError as err:
            print(err, file=sys.stderr)
            return None

    def _print_hits(hits_list: typing.Iterable[TopHits]) -> None:
        """Print one tab-separated line on stdout for every reported hit."""
        for hits in hits_list:
            for hit in hits:
                if not hit.is_reported():
                    continue
                alignment = hit.best_domain.alignment
                print(
                    hit.name.decode(),
                    "-",
                    alignment.hmm_accession.decode(),
                    alignment.hmm_name.decode(),
                    hit.evalue,
                    hit.score,
                    hit.bias,
                    sep="\t",
                )

    def _hmmsearch(args: argparse.Namespace) -> int:
        sequences = _load_sequences(args.seqdb)
        if sequences is None:
            return 1

        with HMMFile(args.hmmfile) as hmms:
            queries = hmms.optimized_profiles() if hmms.is_pressed() else hmms
            # the queries are read from `hmms`, so the results must be
            # consumed while the HMM file is still open
            _print_hits(hmmsearch(queries, sequences, cpus=args.jobs))  # type: ignore

        return 0

    def _phmmer(args: argparse.Namespace) -> int:
        # queries and targets must share the same (protein) alphabet
        alphabet = Alphabet.amino()

        sequences = _load_sequences(args.seqdb, alphabet=alphabet)
        if sequences is None:
            return 1
        queries = _load_sequences(args.seqfile, alphabet=alphabet)
        if queries is None:
            return 1

        _print_hits(phmmer(queries, sequences, cpus=args.jobs))  # type: ignore
        return 0

    def _nhmmer(args: argparse.Namespace) -> int:
        sequences = _load_sequences(args.seqdb)
        if sequences is None:
            return 1
        queries = _load_sequences(args.seqfile)
        if queries is None:
            return 1

        _print_hits(nhmmer(queries, sequences, cpus=args.jobs))  # type: ignore
        return 0

    def _hmmpress(args: argparse.Namespace) -> int:
        # refuse to overwrite existing pressed files unless --force is given
        for ext in ["h3m", "h3i", "h3f", "h3p"]:
            path = "{}.{}".format(args.hmmfile, ext)
            if not os.path.exists(path):
                continue
            if not args.force:
                print(f"file {path} already exists", file=sys.stderr)
                return 1
            os.remove(path)

        with HMMFile(args.hmmfile) as hmms:
            hmmpress(hmms, args.hmmfile)

        return 0

    def _hmmalign(args: argparse.Namespace) -> int:
        # the HMM file must contain exactly one HMM
        with HMMFile(args.hmmfile) as hmms:
            hmm = next(hmms, None)
            if hmm is None:
                print("HMM file is empty, exiting", file=sys.stderr)
                return 1
            if next(hmms, None) is not None:
                print("HMM file contains more than one HMM, exiting", file=sys.stderr)
                return 1

        # digitize the sequences with the HMM alphabet rather than letting
        # the reader pick one, which could disagree with the model
        sequences = _load_sequences(args.seqfile, args.informat, alphabet=hmm.alphabet)
        if sequences is None:
            return 1

        msa = hmmalign(hmm, sequences, trim=args.trim)

        if args.output == "-":
            # write the raw bytes so non-ASCII names do not fail decoding
            with io.BytesIO() as out:
                msa.write(out, args.outformat)
                sys.stdout.buffer.write(out.getvalue())
                sys.stdout.buffer.flush()
        else:
            with open(args.output, "wb") as out:
                msa.write(out, args.outformat)

        return 0

    parser = argparse.ArgumentParser()
    parser.add_argument("-j", "--jobs", required=False, default=0, type=int)
    subparsers = parser.add_subparsers(
        dest="cmd", help="HMMER command to run", required=True
    )

    parser_hmmsearch = subparsers.add_parser("hmmsearch")
    parser_hmmsearch.set_defaults(call=_hmmsearch)
    parser_hmmsearch.add_argument("hmmfile")
    parser_hmmsearch.add_argument("seqdb")

    parser_phmmer = subparsers.add_parser("phmmer")
    parser_phmmer.set_defaults(call=_phmmer)
    parser_phmmer.add_argument("seqfile")
    parser_phmmer.add_argument("seqdb")

    parser_nhmmer = subparsers.add_parser("nhmmer")
    parser_nhmmer.set_defaults(call=_nhmmer)
    parser_nhmmer.add_argument("seqfile")
    parser_nhmmer.add_argument("seqdb")

    parser_hmmpress = subparsers.add_parser("hmmpress")
    parser_hmmpress.set_defaults(call=_hmmpress)
    parser_hmmpress.add_argument("hmmfile")
    parser_hmmpress.add_argument("-f", "--force", action="store_true")

    parser_hmmalign = subparsers.add_parser("hmmalign")
    parser_hmmalign.set_defaults(call=_hmmalign)
    parser_hmmalign.add_argument(
        "hmmfile",
        metavar="<hmmfile>",
    )
    parser_hmmalign.add_argument(
        "seqfile",
        metavar="<seqfile>",
    )
    parser_hmmalign.add_argument(
        "-o",
        "--output",
        action="store",
        default="-",
        metavar="<f>",
        help="output alignment to file <f>, not stdout",
    )
    parser_hmmalign.add_argument(
        "--trim",
        action="store_true",
        help="trim terminal tails of nonaligned residues from alignment",
    )
    parser_hmmalign.add_argument(
        "--informat",
        action="store",
        metavar="<s>",
        help="assert <seqfile> is in format <s> (no autodetection)",
        choices=SequenceFile._FORMATS.keys(),
    )
    parser_hmmalign.add_argument(
        "--outformat",
        action="store",
        metavar="<s>",
        help="output alignment in format <s>",
        default="stockholm",
        choices=MSAFile._FORMATS.keys(),
    )

    args = parser.parse_args()
    sys.exit(args.call(args))