Skip to content
Snippets Groups Projects
Commit c59025e0 authored by Martin Larralde's avatar Martin Larralde
Browse files

Make `SequenceFile` generic on the sequence type

parent 5433c1a8
No related branches found
No related tags found
No related merge requests found
......@@ -582,7 +582,7 @@ class DigitalSequence(Sequence):
# --- Sequence block ---------------------------------------------------------
S = typing.TypeVar("S", bound=Sequence)
S = typing.TypeVar("S", TextSequence, DigitalSequence)
B = typing.TypeVar("B")
class SequenceBlock(typing.MutableSequence[S], typing.Generic[S]):
......@@ -626,7 +626,7 @@ class DigitalSequenceBlock(SequenceBlock[DigitalSequence]):
# --- Sequence File ----------------------------------------------------------
class SequenceFile(typing.ContextManager[SequenceFile], typing.Iterator[Sequence]):
class SequenceFile(typing.Generic[S], typing.ContextManager[SequenceFile[S]], typing.Iterator[S]):
_FORMATS: typing.ClassVar[typing.Dict[str, int]]
alphabet: typing.Optional[Alphabet]
name: typing.Optional[str]
......@@ -636,24 +636,35 @@ class SequenceFile(typing.ContextManager[SequenceFile], typing.Iterator[Sequence
) -> Sequence: ...
@classmethod
def parseinto(cls, seq: Sequence, buffer: BUFFER, format: str) -> Sequence: ...
@typing.overload
def __init__(
self,
self: SequenceFile[DigitalSequence],
file: typing.Union[typing.AnyStr, os.PathLike[typing.AnyStr], typing.BinaryIO],
format: typing.Optional[str] = None,
*,
ignore_gaps: bool = False,
digital: bool = False,
digital: Literal[True],
alphabet: typing.Optional[Alphabet] = None,
) -> None: ...
@typing.overload
def __init__(
self: SequenceFile[TextSequence],
file: typing.Union[typing.AnyStr, os.PathLike[typing.AnyStr], typing.BinaryIO],
format: typing.Optional[str] = None,
*,
ignore_gaps: bool = False,
digital: Literal[False] = False,
alphabet: typing.Optional[Alphabet] = None,
) -> None: ...
def __enter__(self) -> SequenceFile: ...
def __enter__(self) -> SequenceFile[S]: ...
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> bool: ...
def __iter__(self) -> SequenceFile: ...
def __next__(self) -> Sequence: ...
def __iter__(self) -> SequenceFile[S]: ...
def __next__(self) -> S: ...
def __repr__(self) -> str: ...
@property
def closed(self) -> bool: ...
......@@ -663,15 +674,15 @@ class SequenceFile(typing.ContextManager[SequenceFile], typing.Iterator[Sequence
def format(self) -> str: ...
def read(
self, skip_info: bool = False, skip_sequence: bool = False
) -> typing.Optional[Sequence]: ...
) -> typing.Optional[S]: ...
def readinto(
self, seq: Sequence, skip_info: bool = False, skip_sequence: bool = False
) -> typing.Optional[Sequence]: ...
) -> typing.Optional[S]: ...
def read_block(
self,
sequences: typing.Optional[int] = None,
residues: typing.Optional[int] = None,
) -> SequenceBlock[Sequence]: ...
) -> SequenceBlock[S]: ...
def rewind(self) -> None: ...
def close(self) -> None: ...
def guess_alphabet(self) -> typing.Optional[Alphabet]: ...
......
......@@ -291,7 +291,7 @@ class _BaseWorker(typing.Generic[_Q, _T, _R], threading.Thread):
class _SEARCHWorker(
_BaseWorker[
_SEARCHQueryType,
typing.Union[DigitalSequenceBlock, SequenceFile],
typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
"TopHits[_SEARCHQueryType]",
]
):
......@@ -311,7 +311,7 @@ class _SEARCHWorker(
class _PHMMERWorker(
_BaseWorker[
_PHMMERQueryType,
typing.Union[DigitalSequenceBlock, SequenceFile],
typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
"TopHits[_PHMMERQueryType]",
]
):
......@@ -425,7 +425,7 @@ class _JACKHMMERWorker(
class _NHMMERWorker(
_BaseWorker[
_NHMMERQueryType,
typing.Union[DigitalSequenceBlock, SequenceFile],
typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
"TopHits[_NHMMERQueryType]",
]
):
......@@ -594,7 +594,7 @@ class _BaseDispatcher(typing.Generic[_Q, _T, _R], abc.ABC):
class _SEARCHDispatcher(
_BaseDispatcher[
_SEARCHQueryType,
typing.Union[DigitalSequenceBlock, SequenceFile],
typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
"TopHits[_SEARCHQueryType]",
]
):
......@@ -628,7 +628,7 @@ class _SEARCHDispatcher(
class _PHMMERDispatcher(
_BaseDispatcher[
_PHMMERQueryType,
typing.Union[DigitalSequenceBlock, SequenceFile],
typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
"TopHits[_PHMMERQueryType]",
]
):
......@@ -724,14 +724,14 @@ class _JACKHMMERDispatcher(
class _NHMMERDispatcher(
_BaseDispatcher[
_NHMMERQueryType,
typing.Union[DigitalSequenceBlock, SequenceFile],
typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
"TopHits[_NHMMERQueryType]",
]
):
def __init__(
self,
queries: typing.Iterable[_NHMMERQueryType],
targets: typing.Union[DigitalSequenceBlock, SequenceFile],
targets: typing.Union[DigitalSequenceBlock, "SequenceFile[DigitalSequence]"],
cpus: int = 0,
callback: typing.Optional[
typing.Callable[[_NHMMERQueryType, int], None]
......@@ -813,7 +813,7 @@ class _SCANDispatcher(
def hmmsearch(
queries: typing.Union[_P, typing.Iterable[_P]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Iterable[DigitalSequence],
*,
cpus: int = 0,
callback: typing.Optional[typing.Callable[[_P, int], None]] = None,
......@@ -895,7 +895,7 @@ def hmmsearch(
raise ValueError("expected digital mode `SequenceFile` for targets")
assert sequences.alphabet is not None
alphabet = alphabet or sequences.alphabet
targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequences
targets: typing.Union["SequenceFile[DigitalSequence]", DigitalSequenceBlock] = sequences
elif isinstance(sequences, DigitalSequenceBlock):
alphabet = alphabet or sequences.alphabet
targets = sequences
......@@ -927,7 +927,7 @@ def hmmsearch(
def phmmer(
queries: typing.Union[_M, typing.Iterable[_M]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Iterable[DigitalSequence],
*,
cpus: int = 0,
callback: typing.Optional[typing.Callable[[_M, int], None]] = None,
......@@ -987,14 +987,13 @@ def phmmer(
queries = (queries,)
if isinstance(sequences, SequenceFile):
sequence_file: SequenceFile = sequences
if sequence_file.name is None:
if sequences.name is None:
raise ValueError("expected named `SequenceFile` for targets")
if not sequence_file.digital:
if not sequences.digital:
raise ValueError("expected digital mode `SequenceFile` for targets")
assert sequence_file.alphabet is not None
alphabet = alphabet or sequence_file.alphabet
targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequence_file
assert sequences.alphabet is not None
alphabet = alphabet or sequences.alphabet
targets: typing.Union["SequenceFile[DigitalSequence]", DigitalSequenceBlock] = sequences
elif isinstance(sequences, DigitalSequenceBlock):
alphabet = alphabet or sequences.alphabet
targets = sequences
......@@ -1027,7 +1026,7 @@ def phmmer(
@typing.overload
def jackhmmer(
queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Iterable[DigitalSequence],
*,
max_iterations: typing.Optional[int] = 5,
select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
......@@ -1043,7 +1042,7 @@ def jackhmmer(
@typing.overload
def jackhmmer(
queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Iterable[DigitalSequence],
*,
max_iterations: typing.Optional[int] = 5,
select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
......@@ -1059,7 +1058,7 @@ def jackhmmer(
@typing.overload
def jackhmmer(
queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Union[typing.Iterable[DigitalSequence], "SequenceFile[DigitalSequence]"],
*,
max_iterations: typing.Optional[int] = 5,
select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
......@@ -1076,7 +1075,7 @@ def jackhmmer(
def jackhmmer(
queries: typing.Union[_JACKHMMERQueryType, typing.Iterable[_JACKHMMERQueryType]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Union[typing.Iterable[DigitalSequence], "SequenceFile[DigitalSequence]"],
*,
max_iterations: typing.Optional[int] = 5,
select_hits: typing.Optional[typing.Callable[["TopHits[_JACKHMMERQueryType]"], None]] = None,
......@@ -1209,7 +1208,7 @@ def jackhmmer(
def nhmmer(
queries: typing.Union[_N, typing.Iterable[_N]],
sequences: typing.Union[typing.Iterable[DigitalSequence], SequenceFile],
sequences: typing.Iterable[DigitalSequence],
*,
cpus: int = 0,
callback: typing.Optional[typing.Callable[[_N, int], None]] = None,
......@@ -1279,7 +1278,7 @@ def nhmmer(
raise ValueError("expected digital mode `SequenceFile` for targets")
assert sequences.alphabet is not None
alphabet = alphabet or sequences.alphabet
targets: typing.Union[SequenceFile, DigitalSequenceBlock] = sequences
targets: typing.Union["SequenceFile[DigitalSequence]", DigitalSequenceBlock] = sequences
elif isinstance(sequences, DigitalSequenceBlock):
alphabet = alphabet or sequences.alphabet
targets = sequences
......@@ -1625,7 +1624,7 @@ if __name__ == "__main__":
sequences = sequences.read_block() # type: ignore
# load the query sequences iteratively
with SequenceFile(args.seqfile, digital=True, alphabet=alphabet) as queries:
hits_list = phmmer(queries, sequences, cpus=args.jobs) # type: ignore
hits_list = phmmer(queries, sequences, cpus=args.jobs)
for hits in hits_list:
for hit in hits:
if hit.reported:
......@@ -1646,7 +1645,7 @@ if __name__ == "__main__":
def open_query_file(
queryfile: typing.Union["os.PathLike[str]", typing.BinaryIO],
alphabet: Alphabet,
) -> typing.Iterator[typing.Union[SequenceFile, HMMFile]]:
) -> typing.Iterator[typing.Union["SequenceFile[DigitalSequence]", HMMFile]]:
"""Open either a sequence file or an HMM file."""
try:
yield SequenceFile(queryfile, digital=True, alphabet=alphabet)
......@@ -1691,7 +1690,7 @@ if __name__ == "__main__":
# at the moment `LongTargetsPipeline` only support block targets, not files
with SequenceFile(args.seqdb, digital=True) as seqfile:
with SequenceFile(args.seqfile, digital=True) as queryfile:
hits_list = nhmmer(queryfile, seqfile, cpus=args.jobs) # type: ignore
hits_list = nhmmer(queryfile, seqfile, cpus=args.jobs)
for hits in hits_list:
for hit in hits:
if hit.reported:
......@@ -1757,7 +1756,7 @@ if __name__ == "__main__":
def _hmmalign(args: argparse.Namespace) -> int:
try:
with SequenceFile(args.seqfile, args.informat, digital=True) as seqfile:
sequences: typing.List[DigitalSequence] = list(seqfile) # type: ignore
sequences: typing.List[DigitalSequence] = list(seqfile)
except EOFError as err:
print(err, file=sys.stderr)
return 1
......
......@@ -724,18 +724,18 @@ class Pipeline(object):
def search_hmm(
self,
query: typing.Union[HMM, Profile, OptimizedProfile],
sequences: typing.Union[DigitalSequenceBlock, SequenceFile],
sequences: typing.Union[DigitalSequenceBlock, SequenceFile[DigitalSequence]],
) -> TopHits[typing.Union[HMM, Profile, OptimizedProfile]]: ...
def search_msa(
self,
query: DigitalMSA,
sequences: typing.Union[DigitalSequenceBlock, SequenceFile],
sequences: typing.Union[DigitalSequenceBlock, SequenceFile[DigitalSequence]],
builder: typing.Optional[Builder] = None,
) -> TopHits[DigitalMSA]: ...
def search_seq(
self,
query: DigitalSequence,
sequences: typing.Union[DigitalSequenceBlock, SequenceFile],
sequences: typing.Union[DigitalSequenceBlock, SequenceFile[DigitalSequence]],
builder: typing.Optional[Builder] = None,
) -> TopHits[DigitalSequence]: ...
def scan_seq(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment