diff --git a/docs/examples/fetchmgs.ipynb b/docs/examples/fetchmgs.ipynb index e8df382e4d0365d73c2f4b0e6ce19fcf4e4b5e0a..2ee2de0028f6774f8e37699d35a9c95f3d480cff 100644 --- a/docs/examples/fetchmgs.ipynb +++ b/docs/examples/fetchmgs.ipynb @@ -160,7 +160,7 @@ "source": [ "results = []\n", "for hits in pyhmmer.hmmsearch(hmms, proteins, bit_cutoffs=\"trusted\"):\n", - " cog = hits.query_name.decode()\n", + " cog = hits.query.name.decode()\n", " for hit in hits:\n", " if hit.included:\n", " results.append(Result(hit.name.decode(), cog, hit.score))" diff --git a/pyhmmer/__init__.py b/pyhmmer/__init__.py index a9cac1c9952bdeff83150649d0c92f823886a776..c3736ea4fb734b9edbe23d2e877a97c5fdf91747 100644 --- a/pyhmmer/__init__.py +++ b/pyhmmer/__init__.py @@ -17,7 +17,6 @@ References: :doi:`10.1093/bioinformatics/btad214`. :pmid:`37074928`. """ - import collections.abc as _collections_abc import contextlib as _contextlib import os as _os diff --git a/pyhmmer/daemon.pxd b/pyhmmer/daemon.pxd index 199c615943d0519d680b4c87e5f80942ce0f3b46..4aa1f2cd11173a2b800e45736a5cb182b53afb24 100644 --- a/pyhmmer/daemon.pxd +++ b/pyhmmer/daemon.pxd @@ -22,7 +22,7 @@ cdef class Client: cdef bytearray _recvall(self, size_t message_size) cdef TopHits _client( self, - bytes query, + object query, uint64_t db, list ranges, Pipeline pli, diff --git a/pyhmmer/daemon.pyi b/pyhmmer/daemon.pyi index 727612b76b2d7642e9e46fb690e5f7760e17f26c..4a14b4a250ee1039e7606b54ab97592c8752d8f8 100644 --- a/pyhmmer/daemon.pyi +++ b/pyhmmer/daemon.pyi @@ -13,6 +13,8 @@ from pyhmmer.plan7 import TopHits, HMM, Builder BIT_CUTOFFS = Literal["gathering", "trusted", "noise"] +S = typing.TypeVar("S", bound=Sequence) + class Client: address: str port: int @@ -30,7 +32,7 @@ class Client: def close(self) -> None: ... def search_seq( self, - query: Sequence, + query: S, db: int = 1, ranges: typing.Optional[typing.List[typing.Tuple[int, int]]] = None, *, @@ -51,7 +53,7 @@ class Client: incdomE: float = 0.01, incdomT: typing.Optional[float] = None, bit_cutoffs: typing.Optional[BIT_CUTOFFS] = None, - ) -> TopHits: ... + ) -> TopHits[S]: ... def search_hmm( self, query: HMM, @@ -75,10 +77,10 @@ class Client: incdomE: float = 0.01, incdomT: typing.Optional[float] = None, bit_cutoffs: typing.Optional[BIT_CUTOFFS] = None, - ) -> TopHits: ... + ) -> TopHits[HMM]: ... def scan_seq( self, - query: Sequence, + query: S, db: int = 1, *, bias_filter: bool = True, @@ -98,14 +100,14 @@ class Client: incdomE: float = 0.01, incdomT: typing.Optional[float] = None, bit_cutoffs: typing.Optional[BIT_CUTOFFS] = None, - ) -> TopHits: ... + ) -> TopHits[S]: ... def iterate_seq( self, query: Sequence, db: int = 1, ranges: typing.Optional[typing.List[typing.Tuple[int, int]]] = None, builder: typing.Optional[Builder] = None, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[[TopHits[HMM]], None]] = None, *, bias_filter: bool = True, null2: bool = True, @@ -131,7 +133,7 @@ class Client: db: int = 1, ranges: typing.Optional[typing.List[typing.Tuple[int, int]]] = None, builder: typing.Optional[Builder] = None, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[[TopHits[HMM]], None]] = None, *, bias_filter: bool = True, null2: bool = True, @@ -162,6 +164,6 @@ class IterativeSearch(pyhmmer.plan7.IterativeSearch): db: int, builder: Builder, ranges: typing.Optional[typing.List[typing.Tuple[int, int]]] = None, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[[TopHits[HMM]], None]] = None, options: typing.Optional[typing.Dict[str, object]] = None, ) -> None: ... diff --git a/pyhmmer/daemon.pyx b/pyhmmer/daemon.pyx index db31f05cd756350b90a7d2effe843bdd6e86b086..c161e5ca67145cd628bc55f77ddb365fbd156e3d 100644 --- a/pyhmmer/daemon.pyx +++ b/pyhmmer/daemon.pyx @@ -160,7 +160,7 @@ cdef class Client: cdef TopHits _client( self, - bytes query, + object query, uint64_t db, list ranges, Pipeline pli, @@ -190,7 +190,7 @@ cdef class Client: cdef uint32_t hits_start cdef uint32_t buf_offset = 0 - cdef TopHits hits = TopHits() + cdef TopHits hits = TopHits(query) cdef str options = "".join(pli.arguments()) # check ranges argument @@ -207,6 +207,12 @@ cdef class Client: memset(&search_status, 0, sizeof(HMMD_SEARCH_STATUS)) search_stats.hit_offsets = NULL + # serialize query + with io.BytesIO() as buffer: + query.write(buffer) + buffer.write(b"\n//") + txt = buffer.getvalue() + try: # send the options if mode == p7_pipemodes_e.p7_SEARCH_SEQS: @@ -220,7 +226,7 @@ cdef class Client: self.socket.sendall(f"@--hmmdb {db} {options}\n".encode("ascii")) # send the query - self.socket.sendall(query) + self.socket.sendall(txt) # get the search status back response = self._recvall(HMMD_SEARCH_STATUS_SERIAL_SIZE) @@ -344,22 +350,9 @@ cdef class Client: sequence against the sequence database loaded on the server side. """ - cdef bytes txt - cdef TopHits hits cdef Alphabet abc = getattr(query, "alphabet", Alphabet.amino()) cdef Pipeline pli = Pipeline(abc, **options) - - with io.BytesIO() as buffer: - query.write(buffer) - buffer.write(b"\n//") - txt = buffer.getvalue() - - hits = self._client(txt, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS) - hits._qname = query.name - hits._qacc = query.accession - hits._qlen = len(query) - - return hits + return self._client(query, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS) def search_hmm( self, @@ -386,21 +379,8 @@ cdef class Client: server side. """ - cdef bytes txt - cdef TopHits hits cdef Pipeline pli = Pipeline(query.alphabet, **options) - - with io.BytesIO() as buffer: - query.write(buffer) - buffer.write(b"\n//") - txt = buffer.getvalue() - - hits = self._client(txt, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS) - hits._qname = query.name - hits._qacc = query.accession - hits._qlen = query.M - - return hits + return self._client(query, db, ranges, pli, p7_pipemodes_e.p7_SEARCH_SEQS) def scan_seq(self, Sequence query, uint64_t db = 1, **options): """Search the HMMER daemon database with a query sequence. @@ -419,22 +399,9 @@ cdef class Client: server side. """ - cdef bytes txt - cdef TopHits hits cdef Alphabet abc = getattr(query, "alphabet", Alphabet.amino()) cdef Pipeline pli = Pipeline(abc, **options) - - with io.BytesIO() as buffer: - query.write(buffer) - buffer.write(b"\n//") - txt = buffer.getvalue() - - hits = self._client(txt, db, None, pli, p7_pipemodes_e.p7_SCAN_MODELS) - hits._qname = query.name - hits._qacc = query.accession - hits._qlen = len(query) - - return hits + return self._client(query, db, None, pli, p7_pipemodes_e.p7_SCAN_MODELS) def iterate_seq( self, diff --git a/pyhmmer/hmmer.py b/pyhmmer/hmmer.py index 685dc9b5f1382f0b9f80f4b6e22323919a89f915..e9082f85782dbf080a9365ecf474d76aeee0b5c6 100644 --- a/pyhmmer/hmmer.py +++ b/pyhmmer/hmmer.py @@ -8,7 +8,6 @@ Note: threads. """ - import abc import contextlib import collections @@ -23,6 +22,7 @@ import queue import threading import time import typing +from typing import Any import psutil @@ -62,11 +62,21 @@ _T = typing.TypeVar("_T") _R = typing.TypeVar("_R") _I = typing.TypeVar("_I") +# generic profile type +_P = typing.TypeVar("_P", HMM, Profile, OptimizedProfile) +# generic alignment type +_M = typing.TypeVar("_M", DigitalSequence, DigitalMSA) +# generic nucleotide query type +_N = typing.TypeVar("_N", DigitalSequence, DigitalMSA, HMM, Profile, OptimizedProfile) + +# type aliases +_AnyProfile = typing.Union[HMM, Profile, OptimizedProfile] + # the query types for the different tasks _PHMMERQueryType = typing.Union[DigitalSequence, DigitalMSA] -_SEARCHQueryType = typing.Union[HMM, Profile, OptimizedProfile] -_NHMMERQueryType = typing.Union[_PHMMERQueryType, _SEARCHQueryType] -_JACKHMMERQueryType = typing.Union[DigitalSequence, _SEARCHQueryType] +_SEARCHQueryType = _AnyProfile +_NHMMERQueryType = typing.Union[_PHMMERQueryType, _AnyProfile] +_JACKHMMERQueryType = typing.Union[DigitalSequence, _AnyProfile] # `typing.Literal`` is only available in Python 3.8 and later if typing.TYPE_CHECKING: @@ -243,11 +253,11 @@ class _SEARCHWorker( _BaseWorker[ _SEARCHQueryType, typing.Union[DigitalSequenceBlock, SequenceFile], - TopHits, + "TopHits[_SEARCHQueryType]", ] ): @singledispatchmethod - def query(self, query) -> TopHits: # type: ignore + def query(self, query) -> "TopHits[Any]": # type: ignore raise TypeError( "Unsupported query type for `hmmsearch`: {}".format(type(query).__name__) ) @@ -255,7 +265,7 @@ class _SEARCHWorker( @query.register(HMM) @query.register(Profile) @query.register(OptimizedProfile) - def _(self, query: typing.Union[HMM, Profile, OptimizedProfile]) -> TopHits: # type: ignore + def _(self, query: _AnyProfile) -> "TopHits[_AnyProfile]": # type: ignore return self.pipeline.search_hmm(query, self.targets) @@ -263,21 +273,21 @@ class _PHMMERWorker( _BaseWorker[ _PHMMERQueryType, typing.Union[DigitalSequenceBlock, SequenceFile], - TopHits, + "TopHits[_PHMMERQueryType]", ] ): @singledispatchmethod - def query(self, query) -> TopHits: # type: ignore + def query(self, query) -> "TopHits[Any]": # type: ignore raise TypeError( "Unsupported query type for `phmmer`: {}".format(type(query).__name__) ) @query.register(DigitalSequence) - def _(self, query: DigitalSequence) -> TopHits: # type: ignore + def _(self, query: DigitalSequence) -> "TopHits[DigitalSequence]": # type: ignore return self.pipeline.search_seq(query, self.targets, self.builder) @query.register(DigitalMSA) - def _(self, query: DigitalMSA) -> TopHits: # type: ignore + def _(self, query: DigitalMSA) -> "TopHits[DigitalMSA]": # type: ignore return self.pipeline.search_msa(query, self.targets, self.builder) @@ -301,7 +311,7 @@ class _JACKHMMERWorker( alphabet: Alphabet, builder: typing.Optional[Builder] = None, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: bool = False, ) -> None: super().__init__( @@ -328,14 +338,20 @@ class _JACKHMMERWorker( @query.register(DigitalSequence) def _(self, query: DigitalSequence) -> typing.Union[IterationResult, typing.Iterable[IterationResult]]: # type: ignore iterator = self.pipeline.iterate_seq( - query, self.targets, self.builder, self.select_hits + query, + self.targets, + self.builder, + self.select_hits # type: ignore ) return self._iterate(iterator, self.checkpoints) @query.register(HMM) def _(self, query: HMM) -> typing.Union[IterationResult, typing.Iterable[IterationResult]]: # type: ignore iterator = self.pipeline.iterate_hmm( - query, self.targets, self.builder, self.select_hits + query, + self.targets, + self.builder, + self.select_hits # type: ignore ) return self._iterate(iterator, self.checkpoints) @@ -373,27 +389,27 @@ class _NHMMERWorker( _BaseWorker[ _NHMMERQueryType, typing.Union[DigitalSequenceBlock, SequenceFile], - TopHits, + "TopHits[_NHMMERQueryType]", ] ): @singledispatchmethod - def query(self, query) -> TopHits: # type: ignore + def query(self, query) -> "TopHits[Any]": # type: ignore raise TypeError( "Unsupported query type for `nhmmer`: {}".format(type(query).__name__) ) @query.register(DigitalSequence) - def _(self, query: DigitalSequence) -> TopHits: # type: ignore + def _(self, query: DigitalSequence) -> "TopHits[DigitalSequence]": # type: ignore return self.pipeline.search_seq(query, self.targets, self.builder) @query.register(DigitalMSA) - def _(self, query: DigitalMSA) -> TopHits: # type: ignore + def _(self, query: DigitalMSA) -> "TopHits[DigitalMSA]": # type: ignore return self.pipeline.search_msa(query, self.targets, self.builder) @query.register(HMM) @query.register(Profile) @query.register(OptimizedProfile) - def _(self, query: typing.Union[HMM, Profile, OptimizedProfile]) -> TopHits: # type: ignore + def _(self, query: _AnyProfile) -> "TopHits[_AnyProfile]": # type: ignore return self.pipeline.search_hmm(query, self.targets) @@ -401,17 +417,17 @@ class _SCANWorker( _BaseWorker[ DigitalSequence, typing.Union[OptimizedProfileBlock, HMMPressedFile], - TopHits, + "TopHits[DigitalSequence]", ] ): @singledispatchmethod - def query(self, query) -> TopHits: # type: ignore + def query(self, query) -> "TopHits[Any]": # type: ignore raise TypeError( "Unsupported query type for `hmmscan`: {}".format(type(query).__name__) ) @query.register(DigitalSequence) - def _(self, query: DigitalSequence) -> TopHits: # type: ignore + def _(self, query: DigitalSequence) -> "TopHits[DigitalSequence]": # type: ignore return self.pipeline.scan_seq(query, self.targets) @@ -544,12 +560,12 @@ class _SEARCHDispatcher( _BaseDispatcher[ _SEARCHQueryType, typing.Union[DigitalSequenceBlock, SequenceFile], - TopHits, + "TopHits[_SEARCHQueryType]", ] ): def _new_thread( self, - query_queue: "queue.Queue[typing.Optional[_Chore[_SEARCHQueryType, TopHits]]]", + query_queue: "queue.Queue[typing.Optional[_Chore[_SEARCHQueryType, TopHits[_SEARCHQueryType]]]]", query_count: "multiprocessing.Value[int]", # type: ignore kill_switch: threading.Event, ) -> _SEARCHWorker: @@ -579,12 +595,12 @@ class _PHMMERDispatcher( _BaseDispatcher[ _PHMMERQueryType, typing.Union[DigitalSequenceBlock, SequenceFile], - TopHits, + "TopHits[_PHMMERQueryType]", ] ): def _new_thread( self, - query_queue: "queue.Queue[typing.Optional[_Chore[_PHMMERQueryType, TopHits]]]", + query_queue: "queue.Queue[typing.Optional[_Chore[_PHMMERQueryType, TopHits[_PHMMERQueryType]]]]", query_count: "multiprocessing.Value[int]", # type: ignore kill_switch: threading.Event, ) -> _PHMMERWorker: @@ -634,7 +650,7 @@ class _JACKHMMERDispatcher( builder: typing.Optional[Builder] = None, timeout: int = 1, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: bool = False, **options, # type: object ) -> None: @@ -679,7 +695,7 @@ class _NHMMERDispatcher( _BaseDispatcher[ _NHMMERQueryType, typing.Union[DigitalSequenceBlock, SequenceFile], - TopHits, + "TopHits[_NHMMERQueryType]", ] ): def __init__( @@ -694,7 +710,7 @@ class _NHMMERDispatcher( alphabet: Alphabet = Alphabet.dna(), builder: typing.Optional[Builder] = None, timeout: int = 1, - **options, # type: typing.Dict[str, object] + **options, # type: object ) -> None: super().__init__( queries=queries, @@ -710,7 +726,7 @@ class _NHMMERDispatcher( def _new_thread( self, - query_queue: "queue.Queue[typing.Optional[_Chore[_NHMMERQueryType, TopHits]]]", + query_queue: "queue.Queue[typing.Optional[_Chore[_NHMMERQueryType, TopHits[_NHMMERQueryType]]]]", query_count: "multiprocessing.Value[int]", # type: ignore kill_switch: threading.Event, ) -> _NHMMERWorker: @@ -741,12 +757,12 @@ class _SCANDispatcher( _BaseDispatcher[ DigitalSequence, typing.Union[OptimizedProfileBlock, HMMPressedFile], - TopHits, + "TopHits[DigitalSequence]", ] ): def _new_thread( self, - query_queue: "queue.Queue[typing.Optional[_Chore[DigitalSequence, TopHits]]]", + query_queue: "queue.Queue[typing.Optional[_Chore[DigitalSequence, TopHits[DigitalSequence]]]]", query_count: "multiprocessing.Value[int]", # type: ignore kill_switch: threading.Event, ) -> _SCANWorker: @@ -769,15 +785,14 @@ class _SCANDispatcher( # --- hmmsearch -------------------------------------------------------------- - def hmmsearch( - queries: typing.Union[_SEARCHQueryType, typing.Iterable[_SEARCHQueryType]], + queries: typing.Union[_P, typing.Iterable[_P]], #typing.Union[_SEARCHQueryType, typing.Iterable[_SEARCHQueryType]], sequences: typing.Iterable[DigitalSequence], *, cpus: int = 0, - callback: typing.Optional[typing.Callable[[_SEARCHQueryType, int], None]] = None, - **options, # type: typing.Dict[str, object] -) -> typing.Iterator[TopHits]: + callback: typing.Optional[typing.Callable[[_P, int], None]] = None, + **options, # type: object +) -> typing.Iterator["TopHits[_P]"]: """Search HMM profiles against a sequence database. In HMMER many-to-many comparisons, a *search* is the operation of @@ -870,27 +885,27 @@ def hmmsearch( queries=queries, targets=targets, cpus=_cpus, - callback=callback, + callback=callback, # type: ignore alphabet=alphabet, builder=None, pipeline_class=Pipeline, - **options, + **options, # type: ignore ) - return dispatcher.run() + return dispatcher.run() # type: ignore # --- phmmer ----------------------------------------------------------------- def phmmer( - queries: typing.Union[_PHMMERQueryType, typing.Iterable[_PHMMERQueryType]], + queries: typing.Union[_M, typing.Iterable[_M]], sequences: typing.Iterable[DigitalSequence], *, cpus: int = 0, - callback: typing.Optional[typing.Callable[[_PHMMERQueryType, int], None]] = None, + callback: typing.Optional[typing.Callable[[_M, int], None]] = None, builder: typing.Optional[Builder] = None, - **options, # type: typing.Dict[str, object] -) -> typing.Iterator[TopHits]: + **options, # type: object +) -> typing.Iterator["TopHits[_M]"]: """Search protein sequences against a sequence database. Arguments: @@ -964,13 +979,13 @@ def phmmer( queries=queries, targets=targets, cpus=_cpus, - callback=callback, + callback=callback, # type: ignore pipeline_class=Pipeline, alphabet=alphabet, builder=_builder, - **options, + **options, # type: ignore ) - return dispatcher.run() + return dispatcher.run() # type: ignore # --- jackhmmer ----------------------------------------------------------------- @@ -982,12 +997,12 @@ def jackhmmer( sequences: typing.Iterable[DigitalSequence], *, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: "Literal[True]", cpus: int = 0, callback: typing.Optional[typing.Callable[[_JACKHMMERQueryType, int], None]] = None, builder: typing.Optional[Builder] = None, - **options, # type: typing.Dict[str, object] + **options, # type: object ) -> typing.Iterator[typing.Iterable[IterationResult]]: ... @@ -998,12 +1013,12 @@ def jackhmmer( sequences: typing.Iterable[DigitalSequence], *, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: "Literal[False]", cpus: int = 0, callback: typing.Optional[typing.Callable[[_JACKHMMERQueryType, int], None]] = None, builder: typing.Optional[Builder] = None, - **options, # type: typing.Dict[str, object] + **options, # type: object ) -> typing.Iterator[IterationResult]: ... @@ -1014,12 +1029,12 @@ def jackhmmer( sequences: typing.Iterable[DigitalSequence], *, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: bool = False, cpus: int = 0, callback: typing.Optional[typing.Callable[[_JACKHMMERQueryType, int], None]] = None, builder: typing.Optional[Builder] = None, - **options, # type: typing.Dict[str, object] + **options, # type: object ) -> typing.Union[ typing.Iterator[IterationResult], typing.Iterator[typing.Iterable[IterationResult]] ]: @@ -1031,7 +1046,7 @@ def jackhmmer( sequences: typing.Iterable[DigitalSequence], *, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: bool = False, cpus: int = 0, callback: typing.Optional[typing.Callable[[_JACKHMMERQueryType, int], None]] = None, @@ -1130,7 +1145,7 @@ def jackhmmer( alphabet = _alphabet targets = DigitalSequenceBlock(_alphabet, sequences) - dispatcher = _JACKHMMERDispatcher( # type: ignore + dispatcher = _JACKHMMERDispatcher( queries=queries, targets=targets, cpus=_cpus, @@ -1141,7 +1156,7 @@ def jackhmmer( max_iterations=max_iterations, select_hits=select_hits, checkpoints=checkpoints, - **options, + **options, # type: ignore ) return dispatcher.run() @@ -1150,14 +1165,14 @@ def jackhmmer( def nhmmer( - queries: typing.Union[_NHMMERQueryType, typing.Iterable[_NHMMERQueryType]], + queries: typing.Union[_N, typing.Iterable[_N]], sequences: typing.Iterable[DigitalSequence], *, cpus: int = 0, - callback: typing.Optional[typing.Callable[[_NHMMERQueryType, int], None]] = None, + callback: typing.Optional[typing.Callable[[_N, int], None]] = None, builder: typing.Optional[Builder] = None, - **options, # type: typing.Dict[str, object] -) -> typing.Iterator[TopHits]: + **options, # type: object +) -> typing.Iterator["TopHits[_N]"]: """Search nucleotide sequences against a sequence database. Arguments: @@ -1214,9 +1229,9 @@ def nhmmer( if builder is None: _builder = Builder( _alphabet, - seed=options.get("seed", 42), - window_length=options.get("window_length"), - window_beta=options.get("window_beta"), + seed=typing.cast(int, options.get("seed", 42)), + window_length=typing.cast(typing.Optional[int], options.get("window_length")), + window_beta=typing.cast(typing.Optional[float], options.get("window_beta")), ) else: _builder = builder @@ -1244,13 +1259,13 @@ def nhmmer( queries=queries, targets=targets, cpus=_cpus, - callback=callback, + callback=callback, # type: ignore pipeline_class=LongTargetsPipeline, alphabet=alphabet, builder=_builder, - **options, + **options, # type: ignore ) - return dispatcher.run() + return dispatcher.run() # type: ignore # --- hmmpress --------------------------------------------------------------- @@ -1385,8 +1400,8 @@ def hmmscan( cpus: int = 0, callback: typing.Optional[typing.Callable[[DigitalSequence, int], None]] = None, background: typing.Optional[Background] = None, - **options, # type: typing.Dict[str, object] -) -> typing.Iterator[TopHits]: + **options, # type: object +) -> typing.Iterator["TopHits[DigitalSequence]"]: """Scan query sequences against a profile database. In HMMER many-to-many comparisons, a *scan* is the operation of querying @@ -1460,7 +1475,7 @@ def hmmscan( _alphabet = Alphabet.amino() _cpus = cpus if cpus > 0 else psutil.cpu_count(logical=False) or os.cpu_count() or 1 _background = Background(_alphabet) if background is None else background - options.setdefault("background", _background) # type: ignore + options.setdefault("background", _background) if not isinstance(queries, collections.abc.Iterable): queries = (queries,) @@ -1497,7 +1512,7 @@ def hmmscan( pipeline_class=Pipeline, alphabet=alphabet, builder=None, - **options, + **options, # type: ignore ) return dispatcher.run() @@ -1533,8 +1548,8 @@ if __name__ == "__main__": print( hit.name.decode(), (hit.accession or b"-").decode(), - (hits.query_name or b"-").decode(), - (hits.query_accession or b"-").decode(), + (hits.query.name or b"-").decode(), + (hits.query.accession or b"-").decode(), hit.evalue, hit.score, hit.bias, @@ -1595,9 +1610,14 @@ if __name__ == "__main__": sequences = sequences.read_block() # type: ignore # load the query sequences or HMMs iteratively with open_query_file(args.queryfile, alphabet) as queries: - result = jackhmmer(queries, sequences, checkpoint=False, cpus=args.jobs) - for hits in result.hits_list: - for hit in hits: + results = jackhmmer( + typing.cast(typing.Iterable[HMM], queries), + typing.cast(typing.Iterable[DigitalSequence], sequences), + checkpoints=False, + cpus=typing.cast(int, args.jobs) + ) + for result in results: + for hit in result.hits: if hit.reported: print( hit.name.decode(), @@ -1655,8 +1675,8 @@ if __name__ == "__main__": print( hit.name.decode(), (hit.accession or b"-").decode(), - (hits.query_name or b"-").decode(), - (hits.query_accession or b"-").decode(), + (hits.query.name or b"-").decode(), + (hits.query.accession or b"-").decode(), hit.evalue, hit.score, hit.bias, diff --git a/pyhmmer/plan7.pxd b/pyhmmer/plan7.pxd index 137d22795f16d9d12f0e8c7bd361a94f2ab9c10b..2b2e59777df0e8967823860bb054f20904015854 100644 --- a/pyhmmer/plan7.pxd +++ b/pyhmmer/plan7.pxd @@ -408,9 +408,11 @@ cdef class TopHits: # computed and thresholding can be done correctly. cdef P7_PIPELINE _pli cdef P7_TOPHITS* _th - cdef bytes _qname - cdef bytes _qacc - cdef int _qlen + cdef object _query + + # cdef bytes _qname + # cdef bytes _qacc + # cdef int _qlen cdef int _threshold(self, Pipeline pipeline) except 1 nogil cdef int _sort_by_key(self) except 1 nogil diff --git a/pyhmmer/plan7.pyi b/pyhmmer/plan7.pyi index 5c9fd058ac03bb70ca9ef48eba1a65e4025a0407..8f36d431b9f6f79105501bfbaccfb06cd23b3810 100644 --- a/pyhmmer/plan7.pyi +++ b/pyhmmer/plan7.pyi @@ -6,6 +6,7 @@ import os import sys import types import typing +from typing import Any try: from typing import Literal @@ -41,6 +42,8 @@ STRAND_SIGN = Literal["+", "-"] HITS_FORMAT = Literal["targets", "domain", "pfam"] HITS_MODE = Literal["search", "scan"] +Q = typing.TypeVar("Q") # pipeline query type + class Alignment(collections.abc.Sized): domain: Domain def __len__(self) -> int: ... @@ -277,7 +280,7 @@ class EvalueParameters: def as_vector(self) -> VectorF: ... class Hit(object): - hits: TopHits + hits: TopHits[Any] def __getstate__(self) -> typing.Dict[str, object]: ... @property def name(self) -> bytes: ... @@ -460,7 +463,7 @@ class HMMPressedFile(typing.Iterator[OptimizedProfile]): class IterationResult(typing.NamedTuple): hmm: HMM - hits: TopHits + hits: TopHits[HMM] msa: DigitalMSA converged: bool iteration: int @@ -479,7 +482,7 @@ class IterativeSearch(typing.Iterator[IterationResult]): builder: Builder, query: typing.Union[DigitalSequence, HMM], targets: DigitalSequenceBlock, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[[TopHits[HMM]], None]] = None, ) -> None: ... def __iter__(self) -> IterativeSearch: ... def __next__(self) -> IterationResult: ... @@ -726,37 +729,37 @@ class Pipeline(object): self, query: typing.Union[HMM, Profile, OptimizedProfile], sequences: typing.Union[DigitalSequenceBlock, SequenceFile], - ) -> TopHits: ... + ) -> TopHits[typing.Union[HMM, Profile, OptimizedProfile]]: ... def search_msa( self, query: DigitalMSA, sequences: typing.Union[DigitalSequenceBlock, SequenceFile], builder: typing.Optional[Builder] = None, - ) -> TopHits: ... + ) -> TopHits[DigitalMSA]: ... def search_seq( self, query: DigitalSequence, sequences: typing.Union[DigitalSequenceBlock, SequenceFile], builder: typing.Optional[Builder] = None, - ) -> TopHits: ... + ) -> TopHits[DigitalSequence]: ... def scan_seq( self, query: DigitalSequence, optimized_profiles: typing.Union[OptimizedProfileBlock, HMMPressedFile], - ) -> TopHits: ... + ) -> TopHits[DigitalSequence]: ... def iterate_seq( self, query: DigitalSequence, sequences: DigitalSequenceBlock, builder: typing.Optional[Builder] = None, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[[TopHits[DigitalSequence]], None]] = None, ) -> IterativeSearch: ... def iterate_hmm( self, query: HMM, sequences: DigitalSequenceBlock, builder: typing.Optional[Builder] = None, - select_hits: typing.Optional[typing.Callable[[TopHits], None]] = None, + select_hits: typing.Optional[typing.Callable[[TopHits[DigitalSequence]], None]] = None, ) -> IterativeSearch: ... class LongTargetsPipeline(Pipeline): @@ -853,17 +856,17 @@ class ScoreData(object): def __copy__(self) -> ScoreData: ... def copy(self) -> ScoreData: ... -class TopHits(typing.Sequence[Hit]): - def __init__(self) -> None: ... +class TopHits(typing.Sequence[Hit], typing.Generic[Q]): + def __init__(self, query: Q) -> None: ... def __bool__(self) -> bool: ... - def __copy__(self) -> TopHits: ... - def __deepcopy__(self, memo: typing.Dict[int, object]) -> TopHits: ... + def __copy__(self) -> TopHits[Q]: ... + def __deepcopy__(self, memo: typing.Dict[int, object]) -> TopHits[Q]: ... def __len__(self) -> int: ... @typing.overload def __getitem__(self, index: int) -> Hit: ... @typing.overload def __getitem__(self, index: slice) -> typing.Sequence[Hit]: ... - def __iadd__(self, other: TopHits) -> TopHits: ... + def __iadd__(self, other: TopHits[Q]) -> TopHits[Q]: ... def __getstate__(self) -> typing.Dict[str, object]: ... def __setstate__(self, state: typing.Dict[str, object]) -> None: ... @property @@ -875,6 +878,8 @@ class TopHits(typing.Sequence[Hit]): @property def query_length(self) -> int: ... @property + def query(self) -> Q: ... + @property def Z(self) -> float: ... @property def domZ(self) -> float: ... @@ -917,8 +922,8 @@ class TopHits(typing.Sequence[Hit]): def compare_ranking(self, ranking: KeyHash) -> int: ... def sort(self, by: SORT_KEY = "key") -> None: ... def is_sorted(self, by: SORT_KEY = "key") -> bool: ... - def copy(self) -> TopHits: ... - def merge(self, *others: TopHits) -> TopHits: ... + def copy(self) -> TopHits[Q]: ... + def merge(self, *others: TopHits[Q]) -> TopHits[Q]: ... def to_msa( self, alphabet: Alphabet, diff --git a/pyhmmer/plan7.pyx b/pyhmmer/plan7.pyx index 22a1c19b6957eee85d52511b7c05d2c0a5e9e26d..c78060dcb1bcef35b309387ce7ed407225c630d0 100644 --- a/pyhmmer/plan7.pyx +++ b/pyhmmer/plan7.pyx @@ -5556,7 +5556,6 @@ cdef class Pipeline: if self.profile._gm == NULL: raise AllocationError("P7_PROFILE", sizeof(P7_OPROFILE)) else: - self.profile.clear() # configure the profile from the query HMM self.profile.configure(<HMM> query, self.background, L) @@ -5764,7 +5763,7 @@ cdef class Pipeline: cdef P7_OPROFILE* om cdef int status cdef int allocM - cdef TopHits hits = TopHits() + cdef TopHits hits = TopHits(query) # check that the sequence file is in digital mode if SearchTargets is SequenceFile: @@ -5816,12 +5815,7 @@ cdef class Pipeline: hits._threshold(self) # record the query metadata - hits._qlen = om.M - if om.name != NULL: - hits._qname = PyBytes_FromString(om.name) - if om.acc != NULL: - hits._qacc = PyBytes_FromString(om.acc) - + hits._query = query # return the hits return hits @@ -5877,6 +5871,7 @@ cdef class Pipeline: cdef HMM hmm cdef OptimizedProfile opt cdef Profile profile + cdef TopHits hits # check the pipeline was configured with the same alphabet if not self.alphabet._eq(query.alphabet): @@ -5886,12 +5881,15 @@ cdef class Pipeline: # build the HMM and the profile from the query MSA hmm, profile, opt = builder.build_msa(query, self.background) if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](opt, sequences) + hits = self.search_hmm[DigitalSequenceBlock](opt, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](opt, sequences) + hits = self.search_hmm[SequenceFile](opt, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + # record query metadata + hits._query = query + return hits cpdef TopHits search_seq( self, @@ -5938,6 +5936,7 @@ cdef class Pipeline: cdef HMM hmm cdef OptimizedProfile opt cdef Profile profile + cdef TopHits hits # check the pipeline was configure with the same alphabet if not self.alphabet._eq(query.alphabet): @@ -5949,12 +5948,15 @@ cdef class Pipeline: # build the HMM and the profile from the query sequence hmm, profile, opt = builder.build(query, self.background) if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](opt, sequences) + hits = self.search_hmm[DigitalSequenceBlock](opt, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](opt, sequences) + hits = self.search_hmm[SequenceFile](opt, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + # record query metadata + hits._query = query + return hits @staticmethod cdef int _search_loop( @@ -6139,7 +6141,7 @@ cdef class Pipeline: """ cdef int allocM cdef Profile profile - cdef TopHits hits = TopHits() + cdef TopHits hits = TopHits(query) assert self._pli != NULL @@ -6182,12 +6184,7 @@ cdef class Pipeline: hits._threshold(self) # record the query metadata - hits._qlen = query._sq.L - if query._sq.name != NULL: - hits._qname = PyBytes_FromString(query._sq.name) - if query._sq.acc != NULL: - hits._qacc = PyBytes_FromString(query._sq.acc) - + hits._query = query # return the hits return hits @@ -6843,7 +6840,7 @@ cdef class LongTargetsPipeline(Pipeline): cdef HMM hmm cdef int max_length cdef ScoreData scoredata = ScoreData.__new__(ScoreData) - cdef TopHits hits = TopHits() + cdef TopHits hits = TopHits(query) cdef P7_HIT* hit = NULL cdef P7_OPROFILE* om = NULL @@ -6938,12 +6935,7 @@ cdef class LongTargetsPipeline(Pipeline): hits._pli.pos_output += 1 + llabs(hits._th.hit[j].dcl[0].jali - hits._th.hit[j].dcl[0].iali) # record the query metadata - hits._qlen = om.M - if om.name != NULL: - hits._qname = PyBytes_FromString(om.name) - if om.acc != NULL: - hits._qacc = PyBytes_FromString(om.acc) - + hits._query = query # return the hits return hits @@ -6981,6 +6973,10 @@ cdef class LongTargetsPipeline(Pipeline): """ assert self._pli != NULL + + cdef HMM hmm + cdef TopHits hits + if not self.alphabet._eq(query.alphabet): raise AlphabetMismatch(self.alphabet, query.alphabet) @@ -6996,16 +6992,19 @@ cdef class LongTargetsPipeline(Pipeline): elif builder.window_beta != self.window_beta: raise ValueError("builder and long targets pipeline have different window beta") - cdef HMM hmm = builder.build(query, self.background)[0] + hmm = builder.build(query, self.background)[0] assert hmm._hmm.max_length != -1 if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](hmm, sequences) + hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](hmm, sequences) + hits = self.search_hmm[SequenceFile](hmm, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + hits._query = query + return hits + cpdef TopHits search_msa( self, DigitalMSA query, @@ -7042,6 +7041,9 @@ cdef class LongTargetsPipeline(Pipeline): """ assert self._pli != NULL + cdef HMM hmm + cdef TopHits hits + if not self.alphabet._eq(query.alphabet): raise AlphabetMismatch(self.alphabet, query.alphabet) @@ -7057,15 +7059,18 @@ cdef class LongTargetsPipeline(Pipeline): elif builder.window_beta != self.window_beta: raise ValueError("builder and long targets pipeline have different window beta") - cdef HMM hmm = builder.build_msa(query, self.background)[0] + hmm = builder.build_msa(query, self.background)[0] assert hmm._hmm.max_length != -1 if isinstance(sequences, DigitalSequenceBlock): - return self.search_hmm[DigitalSequenceBlock](hmm, sequences) + hits = self.search_hmm[DigitalSequenceBlock](hmm, sequences) elif isinstance(sequences, SequenceFile): - return self.search_hmm[SequenceFile](hmm, sequences) + hits = self.search_hmm[SequenceFile](hmm, sequences) else: ty = type(sequences).__name__ raise TypeError(f"Expected DigitalSequenceBlock or SequenceFile, found {ty}") + + hits._query = query + return hits @staticmethod cdef int _search_loop_longtargets( @@ -7672,17 +7677,16 @@ cdef class TopHits: def __cinit__(self): self._th = NULL - self._qname = None - self._qacc = None - self._qlen = -1 + self._query = None memset(&self._pli, 0, sizeof(P7_PIPELINE)) - def __init__(self): - """__init__(self)\n--\n + def __init__(self, object query not None): + """__init__(self, query)\n--\n Create an empty `TopHits` instance. """ + self._query = query with nogil: # free allocated memory (in case __init__ is called more than once) libhmmer.p7_tophits.p7_tophits_Destroy(self._th) @@ -7724,7 +7728,7 @@ cdef class TopHits: return self.merge(other) def __reduce__(self): - return TopHits, (), self.__getstate__() + return TopHits, (self.query,), self.__getstate__() def __getstate__(self): assert self._th != NULL @@ -7746,9 +7750,6 @@ cdef class TopHits: hits.append(offset) return { - "qname": self._qname, - "qacc": self._qacc, - "qlen": self._qlen, "unsrt": unsrt, "hit": hits, "Nalloc": self._th.Nalloc, @@ -7813,11 +7814,6 @@ cdef class TopHits: cdef size_t offset cdef VectorU8 hit_state - # record query name and accession - self._qname = state["qname"] - self._qacc = state["qacc"] - self._qlen = state["qlen"] - # deallocate current data if needed if self._th != NULL: libhmmer.p7_tophits.p7_tophits_Destroy(self._th) @@ -7924,8 +7920,19 @@ cdef class TopHits: .. versionadded:: 0.6.1 + .. deprecated:: 0.10.10 + Use ``TopHits.query`` to access the original query directly. + """ - return self._qname + warnings.warn( + "TopHits.query_name has been deprecated in v0.10.10 and will be " + "removed in v0.11.0, use TopHits.query to access the properties of " + "the original query object directly", + DeprecationWarning, + ) + if self._query is None: + return None + return self._query.name @property def query_accession(self): @@ -7933,8 +7940,19 @@ cdef class TopHits: .. versionadded:: 0.6.1 + .. deprecated:: 0.10.10 + Use ``TopHits.query`` to access the original query directly. + """ - return self._qacc + warnings.warn( + "TopHits.query_accession has been deprecated in v0.10.10 and will be " + "removed in v0.11.0, use TopHits.query to access the properties of " + "the original query object directly", + DeprecationWarning, + ) + if self._query is None: + return None + return self._query.accession @property def query_length(self): @@ -7942,8 +7960,36 @@ cdef class TopHits: .. versionadded:: 0.10.5 + .. deprecated:: 0.10.10 + Use ``TopHits.query`` to access the original query directly. + """ - return self._qlen + warnings.warn( + "TopHits.query_length has been deprecated in v0.10.10 and will be " + "removed in v0.11.0, use TopHits.query to access the properties of " + "the original query object directly", + DeprecationWarning, + ) + if self._query is None: + return 0 + return self._query.M if isinstance(self._query, HMM) else len(self._query) + + @property + def query(self): + """`object`: The query object these hits were obtained for. + + The actual type of `TopHits.query` depends on the query that was given + to the `Pipeline`, or the `~pyhmmer.hmmer` function, that created the + object:: + + >>> hits = next(pyhmmer.hmmsearch(thioesterase, proteins)) + >>> hits.query + <HMM alphabet=Alphabet.amino() M=243 name=b'Thioesterase'> + + .. versionadded 0.10.10 + + """ + return self._query @property def Z(self): @@ -8184,7 +8230,7 @@ cdef class TopHits: raise ValueError("Trying to merge `TopHits` obtained from pipelines manually configured to different `domZ` values.") # check threshold modes are consistent if self._pli.by_E != other.by_E: - raise ValueError("Trying to merge `TopHits` obtained from pipelines with different reporting threshold modes") + raise ValueError(f"Trying to merge `TopHits` obtained from pipelines with different reporting threshold modes: {self._pli.by_E} != {other.by_E}") elif self._pli.dom_by_E != other.dom_by_E: raise ValueError("Trying to merge `TopHits` obtained from pipelines with different domain reporting threshold modes") elif self._pli.inc_by_E != other.inc_by_E: @@ -8215,9 +8261,7 @@ cdef class TopHits: cdef TopHits copy = TopHits.__new__(TopHits) # record query metatada - copy._qname = self._qname - copy._qacc = self._qacc - copy._qlen = self._qlen + copy._query = self._query with nogil: # copy pipeline configuration @@ -8443,9 +8487,14 @@ cdef class TopHits: cdef FILE* file cdef str fname cdef int status - cdef char* unk = b"-" - cdef char* qname = unk if self._qname is None else <char*> self._qname - cdef char* qacc = unk if self._qacc is None else <char*> self._qacc + cdef bytes qname = b"-" + cdef bytes qacc = b"-" + + if self._query is not None: + if self._query.name is not None: + qname = self._query.name + if self._query.accession is not None: + qacc = self._query.accession file = fopen_obj(fh, "w") try: @@ -8536,19 +8585,17 @@ cdef class TopHits: # not referenced anywhere else) other_copy = other.copy() + # check that names/accessions are consistent + if merged._query != other._query: + raise ValueError("Trying to merge `TopHits` obtained from different queries") + # just store the copy if merging inside an empty uninitialized `TopHits` - if merged._th.N == 0 and merged._qname is None and merged._qacc is None: - merged._qname = other._qname - merged._qacc = other._qacc - merged._qlen = other._qlen + if merged._th.N == 0: + merged._query = other._query memcpy(&merged._pli, &other_copy._pli, sizeof(P7_PIPELINE)) merged._th, other_copy._th = other_copy._th, merged._th continue - # check that names/accessions are consistent - if merged._qname != other._qname or merged._qacc != other._qacc or merged._qlen != other._qlen: - raise ValueError("Trying to merge `TopHits` obtained from different queries") - # check that the parameters are the same merged._check_threshold_parameters(&other._pli) diff --git a/pyhmmer/tests/test_doctest.py b/pyhmmer/tests/test_doctest.py index 27e1c1d4fe8b684b0e411b5aed63c5e21f127099..1ea4e03cf1579249a3100856c40bdc35867da24a 100644 --- a/pyhmmer/tests/test_doctest.py +++ b/pyhmmer/tests/test_doctest.py @@ -45,22 +45,6 @@ def load_tests(loader, tests, ignore): _current_cwd = os.getcwd() _daemon_client = mock.patch("pyhmmer.daemon.Client") - def setUp(self): - warnings.simplefilter("ignore") - os.chdir(os.path.realpath(os.path.join(__file__, "..", ".."))) - # mock the HMMPGMD client to show usage examples without having - # to actually spawn an HMMPGMD server in the background - _client = _daemon_client.__enter__() - _client.return_value = _client - _client.__enter__.return_value = _client - _client.connect.return_value = None - _client.search_hmm.return_value = pyhmmer.plan7.TopHits() - - def tearDown(self): - os.chdir(_current_cwd) - warnings.simplefilter(warnings.defaultaction) - _daemon_client.__exit__(None, None, None) - # doctests are not compatible with `green`, so we may want to bail out # early if `green` is running the tests if sys.argv[0].endswith("green"): @@ -86,6 +70,22 @@ def load_tests(loader, tests, ignore): with pyhmmer.easel.SequenceFile(seq_path, digital=True) as seq_file: reductase = next(seq for seq in seq_file if b"P12748" in seq.name) + def setUp(self): + warnings.simplefilter("ignore") + os.chdir(os.path.realpath(os.path.join(__file__, "..", ".."))) + # mock the HMMPGMD client to show usage examples without having + # to actually spawn an HMMPGMD server in the background + _client = _daemon_client.__enter__() + _client.return_value = _client + _client.__enter__.return_value = _client + _client.connect.return_value = None + _client.search_hmm.return_value = pyhmmer.plan7.TopHits(thioesterase) + + def tearDown(self): + os.chdir(_current_cwd) + warnings.simplefilter(warnings.defaultaction) + _daemon_client.__exit__(None, None, None) + # recursively traverse all library submodules and load tests from them packages = [None, pyhmmer] diff --git a/pyhmmer/tests/test_hmmer.py b/pyhmmer/tests/test_hmmer.py index b13eba1f7f48fe48e9280369e997c1746322019f..9bf2bf7028968f949a1476b3b14bb85350e2e52a 100644 --- a/pyhmmer/tests/test_hmmer.py +++ b/pyhmmer/tests/test_hmmer.py @@ -764,7 +764,7 @@ class TestHMMScan(unittest.TestCase): seqs_path, digital=True, alphabet=hmms[0].alphabet ) as seqs_file: for hits in pyhmmer.hmmer.hmmscan(seqs_file, hmms, cpus=1): - expected_lines = expected.get(hits.query_name.decode()) + expected_lines = expected.get(hits.query.name.decode()) if expected_lines is None: self.assertEqual(len(hits), 0) continue @@ -798,7 +798,7 @@ class TestHMMScan(unittest.TestCase): with SequenceFile(seqs_path, digital=True) as seqs_file: with HMMPressedFile(db_path) as pressed_file: for hits in pyhmmer.hmmer.hmmscan(seqs_file, pressed_file, cpus=1): - expected_lines = expected.get(hits.query_name.decode()) + expected_lines = expected.get(hits.query.name.decode()) if expected_lines is None: self.assertEqual(len(hits), 0) continue diff --git a/pyhmmer/tests/test_plan7/test_alignment.py b/pyhmmer/tests/test_plan7/test_alignment.py index 4e86fb9a1e5df0aefdff47777f89ddfb062cd2ba..adcd672e29f293e7373dfd1098379faa56cf0254 100644 --- a/pyhmmer/tests/test_plan7/test_alignment.py +++ b/pyhmmer/tests/test_plan7/test_alignment.py @@ -37,7 +37,7 @@ class TestAlignment(unittest.TestCase): rendered = str(self.ali) lines = rendered.splitlines() self.assertEqual(len(lines), 5) - self.assertTrue(lines[1].strip().startswith(self.hits.query_name.decode())) + self.assertTrue(lines[1].strip().startswith(self.hits.query.name.decode())) self.assertTrue(lines[3].strip().startswith(self.hits[0].name.decode())) @unittest.skipIf(sys.implementation.name == "pypy", "`getsizeof` not supported on PyPY") diff --git a/pyhmmer/tests/test_plan7/test_pipeline.py b/pyhmmer/tests/test_plan7/test_pipeline.py index 594ae9b9378fff33301f8a727a141bd3369d6b86..139694bc01f5d5175ef34add9ebc57d2b3e49e9e 100644 --- a/pyhmmer/tests/test_plan7/test_pipeline.py +++ b/pyhmmer/tests/test_plan7/test_pipeline.py @@ -72,8 +72,8 @@ class TestSearchPipeline(unittest.TestCase): pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_hmm(hmm, self.references) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_file(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1") @@ -83,8 +83,8 @@ class TestSearchPipeline(unittest.TestCase): with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file: hits = pipeline.search_hmm(hmm, seqs_file) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_unnamed(self): # make sure `Pipeline.search_hmm` doesn't crash when given an HMM with no name @@ -94,16 +94,16 @@ class TestSearchPipeline(unittest.TestCase): hmm.accession = None pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_hmm(hmm, self.references) - self.assertEqual(hits.query_name, b"test") - self.assertIs(hits.query_accession, None) + self.assertEqual(hits.query.name, b"test") + self.assertIs(hits.query.accession, None) def test_search_seq_block(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1", accession=b"SQ001") pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_seq(seq.digitize(self.alphabet), self.references) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, seq.name) - # self.assertEqual(hits.query_accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... + self.assertEqual(hits.query.name, seq.name) + self.assertEqual(hits.query.accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... def test_search_seq_file(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1", accession=b"SQ001") @@ -111,23 +111,23 @@ class TestSearchPipeline(unittest.TestCase): with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file: hits = pipeline.search_seq(seq.digitize(self.alphabet), seqs_file) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, seq.name) - # self.assertEqual(hits.query_accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... + self.assertEqual(hits.query.name, seq.name) + self.assertEqual(hits.query.accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... def test_search_msa_block(self): pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_msa(self.msa, self.references) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, self.msa.name) - self.assertEqual(hits.query_accession, self.msa.accession) + self.assertEqual(hits.query.name, self.msa.name) + self.assertEqual(hits.query.accession, self.msa.accession) def test_search_msa_file(self): pipeline = Pipeline(alphabet=self.alphabet) with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file: hits = pipeline.search_msa(self.msa, seqs_file) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, self.msa.name) - self.assertEqual(hits.query_accession, self.msa.accession) + self.assertEqual(hits.query.name, self.msa.name) + self.assertEqual(hits.query.accession, self.msa.accession) def test_Z(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1") @@ -330,8 +330,8 @@ class TestLongTargetsPipeline(unittest.TestCase): pipeline = LongTargetsPipeline(alphabet=dna) hits = pipeline.search_hmm(hmm, targets) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_file(self): dna = Alphabet.dna() @@ -350,8 +350,8 @@ class TestLongTargetsPipeline(unittest.TestCase): pipeline = LongTargetsPipeline(alphabet=dna) hits = pipeline.search_hmm(hmm, targets) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_alphabet_mismatch(self): dna = Alphabet.dna() diff --git a/pyhmmer/tests/test_plan7/test_tophits.py b/pyhmmer/tests/test_plan7/test_tophits.py index 2403922c5b27554e75b6c2dca4f960e4ced1464e..7a6ecfc4da050acacb0da060dd2440456f186819 100644 --- a/pyhmmer/tests/test_plan7/test_tophits.py +++ b/pyhmmer/tests/test_plan7/test_tophits.py @@ -83,9 +83,9 @@ class TestTopHits(unittest.TestCase): self.assertEqual(sum(d.included for d in h1.domains), sum(d.included for d in h2.domains)) def assertHitsEqual(self, hits1, hits2): - self.assertEqual(hits1.query_name, hits2.query_name) - self.assertEqual(hits1.query_accession, hits2.query_accession) - self.assertEqual(hits1.query_length, hits2.query_length) + self.assertEqual(hits1.query.name, hits2.query.name) + self.assertEqual(hits1.query.accession, hits2.query.accession) + self.assertEqual(hits1.query, hits2.query) self.assertEqual(len(hits1), len(hits2)) self.assertEqual(len(hits1.included), len(hits2.included)) self.assertEqual(len(hits1.reported), len(hits2.reported)) @@ -114,7 +114,7 @@ class TestTopHits(unittest.TestCase): self.assertEqual(search_hits.mode, "search") def test_bool(self): - self.assertFalse(pyhmmer.plan7.TopHits()) + self.assertFalse(pyhmmer.plan7.TopHits(self.hmm)) self.assertTrue(self.hits) def test_index_error(self): @@ -131,36 +131,36 @@ class TestTopHits(unittest.TestCase): self.assertEqual(dom.name, dom_last.name) def test_Z(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertEqual(empty.Z, 0) self.assertEqual(self.hits.Z, len(self.seqs)) def test_strand(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertIs(empty.strand, None) def test_searched_sequences(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertEqual(empty.searched_sequences, 0) self.assertEqual(self.hits.searched_sequences, len(self.seqs)) def test_searched_nodes(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertEqual(empty.searched_nodes, 0) self.assertEqual(self.hits.searched_nodes, self.hmm.M) def test_searched_residues(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertEqual(empty.searched_residues, 0) self.assertEqual(self.hits.searched_residues, sum(map(len, self.seqs))) def test_searched_models(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertEqual(empty.searched_sequences, 0) self.assertEqual(self.hits.searched_models, 1) def test_merge_empty(self): - empty = TopHits() + empty = TopHits(self.hmm) self.assertFalse(empty.long_targets) self.assertEqual(empty.Z, 0.0) self.assertEqual(empty.domZ, 0.0) @@ -170,7 +170,7 @@ class TestTopHits(unittest.TestCase): self.assertEqual(empty2.Z, 0.0) self.assertEqual(empty2.domZ, 0.0) - merged_empty = empty.merge(TopHits()) + merged_empty = empty.merge(TopHits(self.hmm)) self.assertHitsEqual(merged_empty, empty) self.assertEqual(merged_empty.searched_residues, 0) self.assertEqual(merged_empty.searched_sequences, 0) @@ -180,9 +180,9 @@ class TestTopHits(unittest.TestCase): self.assertEqual(merged_empty.domZ, 0.0) merged = empty.merge(self.hits) - self.assertEqual(merged.query_name, self.hits.query_name) - self.assertEqual(merged.query_length, self.hits.query_length) - self.assertEqual(merged.query_accession, self.hits.query_accession) + self.assertEqual(merged.query.name, self.hits.query.name) + self.assertEqual(merged.query.M, self.hits.query.M) + self.assertEqual(merged.query.accession, self.hits.query.accession) self.assertEqual(merged.E, self.hits.E) self.assertHitsEqual(merged, self.hits) @@ -206,9 +206,10 @@ class TestTopHits(unittest.TestCase): self.assertEqual(merged.searched_sequences, hits.searched_sequences) self.assertEqual(merged.Z, hits.Z) self.assertEqual(merged.domZ, hits.domZ) - self.assertEqual(merged.query_name, hits.query_name) - self.assertEqual(merged.query_length, hits.query_length) - self.assertEqual(merged.query_accession, hits.query_accession) + self.assertIs(merged.query, hits.query) + self.assertEqual(merged.query.name, hits.query.name) + self.assertEqual(merged.query.M, hits.query.M) + self.assertEqual(merged.query.accession, hits.query.accession) self.assertEqual(merged.E, hits.E) self.assertEqual(merged.domE, hits.domE) @@ -244,9 +245,10 @@ class TestTopHits(unittest.TestCase): self.assertEqual(merged.searched_sequences, hits.searched_sequences) self.assertEqual(merged.Z, hits.Z) self.assertEqual(merged.domZ, hits.domZ) - self.assertEqual(merged.query_name, hits.query_name) - self.assertEqual(merged.query_length, hits.query_length) - self.assertEqual(merged.query_accession, hits.query_accession) + self.assertIs(merged.query, hits.query) + self.assertEqual(merged.query.name, hits.query.name) + self.assertEqual(merged.query.M, hits.query.M) + self.assertEqual(merged.query.accession, hits.query.accession) self.assertEqual(merged.E, hits.E) self.assertEqual(merged.domE, hits.domE) @@ -340,14 +342,17 @@ class TestTopHits(unittest.TestCase): pickled = pickle.loads(pickle.dumps(self.hits)) self.assertHitsEqual(pickled, self.hits) + def test_query(self): + self.assertIs(self.hits.query, self.hmm) + def test_query_name(self): - self.assertEqual(self.hits.query_name, self.hmm.name) + self.assertEqual(self.hits.query.name, self.hmm.name) def test_query_accession(self): - self.assertEqual(self.hits.query_accession, self.hmm.accession) + self.assertEqual(self.hits.query.accession, self.hmm.accession) def test_query_length(self): - self.assertEqual(self.hits.query_length, self.hmm.M) + self.assertEqual(self.hits.query.M, self.hmm.M) def test_write_target(self): buffer = io.BytesIO() diff --git a/pyhmmer/utils.py b/pyhmmer/utils.py index ebeaefc4c8cc05d0c4c3a586d6d71e794b01a2b2..0d078f89e7c4d1bdfaffdade8ee09aea093cf80b 100644 --- a/pyhmmer/utils.py +++ b/pyhmmer/utils.py @@ -75,25 +75,33 @@ class singledispatchmethod(typing.Generic[_T]): return _method @typing.overload - def register( - self, cls: typing.Type[typing.Any], method: None = ... + def register( # type: ignore + self, + cls: typing.Type[typing.Any], + method: None = ... ) -> typing.Callable[[typing.Callable[..., _T]], typing.Callable[..., _T]]: ... @typing.overload - def register( - self, cls: typing.Callable[..., _T], method: None = ... + def register( # type: ignore + self, + cls: typing.Callable[..., _T], + method: None = ... ) -> typing.Callable[..., _T]: ... @typing.overload - def register( - self, cls: typing.Type[typing.Any], method: typing.Callable[..., _T] + def register( # type: ignore + self, + cls: typing.Type[typing.Any], + method: typing.Callable[..., _T] ) -> typing.Callable[..., _T]: ... def register( - self, cls: typing.Any, method: typing.Optional[typing.Callable[..., _T]] = None + self, + cls: typing.Any, + method: typing.Optional[typing.Callable[..., _T]] = None ) -> typing.Any: """Registers a new implementation for the given class.""" return self.dispatcher.register(cls, func=method)