diff --git a/pyhmmer/hmmer.py b/pyhmmer/hmmer.py index 5a3a0c5662129553ffda97aa0a83922f3ced2913..e9082f85782dbf080a9365ecf474d76aeee0b5c6 100644 --- a/pyhmmer/hmmer.py +++ b/pyhmmer/hmmer.py @@ -62,6 +62,13 @@ _T = typing.TypeVar("_T") _R = typing.TypeVar("_R") _I = typing.TypeVar("_I") +# generic profile type +_P = typing.TypeVar("_P", HMM, Profile, OptimizedProfile) +# generic alignment type +_M = typing.TypeVar("_M", DigitalSequence, DigitalMSA) +# generic nucleotide query type +_N = typing.TypeVar("_N", DigitalSequence, DigitalMSA, HMM, Profile, OptimizedProfile) + # type aliases _AnyProfile = typing.Union[HMM, Profile, OptimizedProfile] @@ -778,15 +785,14 @@ class _SCANDispatcher( # --- hmmsearch -------------------------------------------------------------- - def hmmsearch( - queries: typing.Union[_SEARCHQueryType, typing.Iterable[_SEARCHQueryType]], + queries: typing.Union[_P, typing.Iterable[_P]], #typing.Union[_SEARCHQueryType, typing.Iterable[_SEARCHQueryType]], sequences: typing.Iterable[DigitalSequence], *, cpus: int = 0, - callback: typing.Optional[typing.Callable[[_SEARCHQueryType, int], None]] = None, + callback: typing.Optional[typing.Callable[[_P, int], None]] = None, **options, # type: object -) -> typing.Iterator["TopHits[_SEARCHQueryType]"]: +) -> typing.Iterator["TopHits[_P]"]: """Search HMM profiles against a sequence database. In HMMER many-to-many comparisons, a *search* is the operation of @@ -879,27 +885,27 @@ def hmmsearch( queries=queries, targets=targets, cpus=_cpus, - callback=callback, + callback=callback, # type: ignore alphabet=alphabet, builder=None, pipeline_class=Pipeline, **options, # type: ignore ) - return dispatcher.run() + return dispatcher.run() # type: ignore # --- phmmer ----------------------------------------------------------------- def phmmer( - queries: typing.Union[_PHMMERQueryType, typing.Iterable[_PHMMERQueryType]], + queries: typing.Union[_M, typing.Iterable[_M]], sequences: typing.Iterable[DigitalSequence], *, cpus: int = 0, - callback: typing.Optional[typing.Callable[[_PHMMERQueryType, int], None]] = None, + callback: typing.Optional[typing.Callable[[_M, int], None]] = None, builder: typing.Optional[Builder] = None, **options, # type: object -) -> typing.Iterator["TopHits[_PHMMERQueryType]"]: +) -> typing.Iterator["TopHits[_M]"]: """Search protein sequences against a sequence database. Arguments: @@ -973,13 +979,13 @@ def phmmer( queries=queries, targets=targets, cpus=_cpus, - callback=callback, + callback=callback, # type: ignore pipeline_class=Pipeline, alphabet=alphabet, builder=_builder, **options, # type: ignore ) - return dispatcher.run() + return dispatcher.run() # type: ignore # --- jackhmmer ----------------------------------------------------------------- @@ -1040,7 +1046,7 @@ def jackhmmer( sequences: typing.Iterable[DigitalSequence], *, max_iterations: typing.Optional[int] = 5, - select_hits: typing.Optional[typing.Callable[["TopHits[Any]"], None]] = None, + select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None, checkpoints: bool = False, cpus: int = 0, callback: typing.Optional[typing.Callable[[_JACKHMMERQueryType, int], None]] = None, @@ -1159,14 +1165,14 @@ def jackhmmer( def nhmmer( - queries: typing.Union[_NHMMERQueryType, typing.Iterable[_NHMMERQueryType]], + queries: typing.Union[_N, typing.Iterable[_N]], sequences: typing.Iterable[DigitalSequence], *, cpus: int = 0, - callback: typing.Optional[typing.Callable[[_NHMMERQueryType, int], None]] = None, + callback: typing.Optional[typing.Callable[[_N, int], None]] = None, builder: typing.Optional[Builder] = None, **options, # type: object -) -> typing.Iterator["TopHits[_NHMMERQueryType]"]: +) -> typing.Iterator["TopHits[_N]"]: """Search nucleotide sequences against a sequence database. Arguments: @@ -1253,13 +1259,13 @@ def nhmmer( queries=queries, targets=targets, cpus=_cpus, - callback=callback, + callback=callback, # type: ignore pipeline_class=LongTargetsPipeline, alphabet=alphabet, builder=_builder, **options, # type: ignore ) - return dispatcher.run() + return dispatcher.run() # type: ignore # --- hmmpress --------------------------------------------------------------- diff --git a/pyhmmer/utils.py b/pyhmmer/utils.py index ebeaefc4c8cc05d0c4c3a586d6d71e794b01a2b2..0d078f89e7c4d1bdfaffdade8ee09aea093cf80b 100644 --- a/pyhmmer/utils.py +++ b/pyhmmer/utils.py @@ -75,25 +75,33 @@ class singledispatchmethod(typing.Generic[_T]): return _method @typing.overload - def register( - self, cls: typing.Type[typing.Any], method: None = ... + def register( # type: ignore + self, + cls: typing.Type[typing.Any], + method: None = ... ) -> typing.Callable[[typing.Callable[..., _T]], typing.Callable[..., _T]]: ... @typing.overload - def register( - self, cls: typing.Callable[..., _T], method: None = ... + def register( # type: ignore + self, + cls: typing.Callable[..., _T], + method: None = ... ) -> typing.Callable[..., _T]: ... @typing.overload - def register( - self, cls: typing.Type[typing.Any], method: typing.Callable[..., _T] + def register( # type: ignore + self, + cls: typing.Type[typing.Any], + method: typing.Callable[..., _T] ) -> typing.Callable[..., _T]: ... def register( - self, cls: typing.Any, method: typing.Optional[typing.Callable[..., _T]] = None + self, + cls: typing.Any, + method: typing.Optional[typing.Callable[..., _T]] = None ) -> typing.Any: """Registers a new implementation for the given class.""" return self.dispatcher.register(cls, func=method)