Commits (29)
......@@ -5,7 +5,17 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.5.4...master
## [v0.5.4] - 2021-02-28
[v0.5.4]: https://git.embl.de/grp-zeller/GECCO/compare/v0.5.3...v0.5.4
### Changed
- Replaced `verboselogs`, `coloredlogs` and `better-exceptions` with `rich`.
### Removed
- `tqdm` training dependency.
### Added
- `gecco annotate` command to produce a feature table from a genomic file.
- `gecco embed` to embed BGCs into non-BGC regions using feature tables.
## [v0.5.3] - 2021-02-21
[v0.5.3]: https://git.embl.de/grp-zeller/GECCO/compare/v0.5.2...v0.5.3
......
# coding: utf-8
import contextlib
import sys
import typing
from typing import Optional, List, TextIO
if typing.TYPE_CHECKING:
import logging
def main(
    argv: Optional[List[str]] = None,
    stream: Optional[TextIO] = None,
) -> int:
    """Run the main ``gecco`` command-line entry point.

    Arguments:
        argv: the command-line arguments to parse, or `None` to use the
            contents of `sys.argv`.
        stream: the stream to write output to, or `None` to use `sys.stderr`.

    Returns:
        `int`: an exit code, with 0 on success and any other number on error.
    """
    # imported locally so that importing `gecco` itself stays cheap
    from .commands._main import Main
    # run the command inside an `ExitStack` so that any context entered by
    # a subcommand (progress bar, warnings patch) is always cleaned up
    with contextlib.ExitStack() as ctx:
        return Main(argv, stream).execute(ctx)
......@@ -9,115 +9,25 @@ import warnings
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Type, TextIO
import numpy
import verboselogs
from .._meta import classproperty
if typing.TYPE_CHECKING:
_S = typing.TypeVar("_S")
_T = typing.TypeVar("_T")
_F = typing.TypeVar("_F", bound=Callable[..., "_T"])
class BraceAdapter(logging.LoggerAdapter, verboselogs.VerboseLogger):
    """A logging adapter for `VerboseLogger` to use new-style formatting.
    """

    class Message(object):
        """A lazy wrapper deferring `str.format` until the record is emitted."""

        def __init__(self, fmt: object, args: Iterable[object]):
            self.fmt = str(fmt)
            self.args = args

        def __str__(self) -> str:
            return self.fmt.format(*self.args)

    def __init__(
        self, logger: logging.Logger, extra: Optional[Dict[str, object]] = None
    ) -> None:
        super(BraceAdapter, self).__init__(logger, extra or {})

    @property
    def level(self) -> int:
        # expose the level of the wrapped logger
        return self.logger.level

    def log(self, level: int, msg: str, *args: object, **kwargs: Any) -> None:
        if self.isEnabledFor(level):
            msg, kw = self.process(msg, kwargs)
            self.logger._log(level, self.Message(msg, args), (), **kw)

    def notice(self, msg: str, *args: object, **kwargs: Any) -> None:
        if self.isEnabledFor(verboselogs.NOTICE):
            msg, kw = self.process(msg, kwargs)
            # FIX: `Logger._log` requires an `args` tuple positionally; it was
            # missing here (compare `log` above), causing a TypeError.
            self.logger._log(verboselogs.NOTICE, self.Message(msg, args), (), **kw)

    def spam(self, msg: str, *args: object, **kwargs: Any) -> None:
        if self.isEnabledFor(verboselogs.SPAM):
            msg, kw = self.process(msg, kwargs)
            # FIX: pass the empty `args` tuple required by `Logger._log`
            self.logger._log(verboselogs.SPAM, self.Message(msg, args), (), **kw)

    def verbose(self, msg: str, *args: object, **kwargs: Any) -> None:
        if self.isEnabledFor(verboselogs.VERBOSE):
            msg, kw = self.process(msg, kwargs)
            # FIX: pass the empty `args` tuple required by `Logger._log`
            self.logger._log(verboselogs.VERBOSE, self.Message(msg, args), (), **kw)

    def success(self, msg: str, *args: object, **kwargs: Any) -> None:
        if self.isEnabledFor(verboselogs.SUCCESS):
            msg, kw = self.process(msg, kwargs)
            # FIX: pass the empty `args` tuple required by `Logger._log`
            self.logger._log(verboselogs.SUCCESS, self.Message(msg, args), (), **kw)
def wrap_warnings(logger: logging.Logger) -> Callable[["_F"], "_F"]:
"""Have the function patch `warnings.showwarning` with the given logger.
Arguments:
logger (~logging.Logger): the logger to wrap warnings with when
the decorated function is called.
Returns:
`function`: a decorator function that will wrap a callable and
redirect any warning raised by that callable to the given logger.
Example:
>>> logger = logging.getLogger()
>>> @wrap_warnings(logger)
... def divide_by_zero(x):
... return numpy.array(x) / 0
# Signature of a `warnings.showwarning`-compatible callback:
# (message, category, filename, lineno, file=None, line=None) -> None
ShowWarning = typing.Callable[
    [str, Type[Warning], str, int, Optional[TextIO], Optional[str]],
    None
]
@contextlib.contextmanager
def patch_showwarnings(new_showwarning: "ShowWarning") -> Iterator[None]:
    """Make a context patching `warnings.showwarning` with the given function.

    The previous handler is restored when the context exits, even when an
    exception is raised inside the ``with`` block.

    Arguments:
        new_showwarning: a `warnings.showwarning`-compatible callable to
            install for the duration of the context.
    """
    # NOTE: the diff residue of the old `wrap_warnings` decorator
    # (`_WarningsWrapper`) that was interleaved here has been removed;
    # only the context-manager implementation remains.
    old_showwarning = warnings.showwarning
    try:
        warnings.showwarning = new_showwarning
        yield
    finally:
        warnings.showwarning = old_showwarning
@contextlib.contextmanager
......@@ -171,3 +81,13 @@ def guess_sequences_format(path: str) -> Optional[str]:
return "genbank"
else:
return None
def in_context(func):
    """Decorate *func* so it receives a fresh `contextlib.ExitStack`.

    The stack is appended as the last positional argument and is closed
    automatically once the wrapped call returns.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with contextlib.ExitStack() as stack:
            return func(*args, stack, **kwargs)
    return wrapper
# coding: utf-8
import abc
import contextlib
import datetime
import logging
import os
import socket
import sys
import textwrap
import typing
from typing import Any, ClassVar, Optional, List, Mapping, Dict, TextIO
from typing import Any, ClassVar, Callable, Optional, List, Mapping, Dict, TextIO, Type
import coloredlogs
import docopt
import verboselogs
import rich.console
import rich.progress
import rich.logging
from ... import __version__, __name__ as __progname__
from .._utils import BraceAdapter
_T = typing.TypeVar("_T")
class InvalidArgument(ValueError):
    """An error to mark an invalid value was passed to a CLI flag.

    Raised by `Command._check_flag` after the bad value has been reported,
    and typically converted into a `CommandExit` by the calling command.
    """
class CommandExit(Exception):
    """An error to request immediate exit from a function.

    Attributes:
        code (`int`): the exit code to terminate the program with.
    """

    def __init__(self, code):
        # FIX: call the base constructor so `str(exc)` and `exc.args`
        # carry the exit code instead of being empty.
        super().__init__(code)
        self.code = code
class Command(metaclass=abc.ABCMeta):
"""An abstract base class for ``gecco`` subcommands.
......@@ -20,11 +37,24 @@ class Command(metaclass=abc.ABCMeta):
# -- Abstract methods ----------------------------------------------------
doc: ClassVar[str] = NotImplemented
summary: ClassVar[str] = NotImplemented
@abc.abstractmethod
def __call__(self) -> int:
def execute(self, ctx: contextlib.ExitStack) -> int:
"""Execute the command.
Returns:
`int`: The exit code for the command, with 0 on success, and any
other number on error.
"""
return NotImplemented # type: ignore
@classmethod
@abc.abstractmethod
def doc(cls, fast: bool = False) -> str:
"""Get the help message for the command.
"""
return NotImplemented # type: ignore
# -- Concrete methods ----------------------------------------------------
......@@ -36,11 +66,9 @@ class Command(metaclass=abc.ABCMeta):
self,
argv: Optional[List[str]] = None,
stream: Optional[TextIO] = None,
logger: Optional[logging.Logger] = None,
options: Optional[Mapping[str, Any]] = None,
config: Optional[Dict[Any, Any]] = None,
) -> None:
self._stream: Optional[TextIO] = stream
self.argv = argv
self.stream: TextIO = stream or sys.stderr
......@@ -48,49 +76,122 @@ class Command(metaclass=abc.ABCMeta):
self.pool = None
self.config = config
self._hostname = socket.gethostname()
self._pid = os.getpid()
# Parse command line arguments
try:
self.args = docopt.docopt(
textwrap.dedent(self.doc).lstrip(),
textwrap.dedent(self.doc(fast=True)).lstrip(),
help=False,
argv=argv,
version=self._version,
options_first=self._options_first,
)
loglevel = self._get_log_level()
self.verbose = self.args.get("--verbose", 0)
self.quiet = self.args.get("--quiet", 0)
except docopt.DocoptExit as de:
self.args = de
loglevel = None
# Create a new colored logger if needed
if logger is None:
logger = verboselogs.VerboseLogger(__progname__)
loglevel = (loglevel or "INFO").upper()
coloredlogs.install(
logger=logger,
level=int(loglevel) if loglevel.isdigit() else loglevel,
stream=self.stream,
)
# Use a loggin adapter to use new-style formatting
self.logger = BraceAdapter(logger)
def _check(self) -> Optional[int]:
self.level = 0
self.quiet = 0
self.progress = rich.progress.Progress(
rich.progress.SpinnerColumn(finished_text="[green]:heavy_check_mark:[/]"),
"[progress.description]{task.description}",
rich.progress.BarColumn(bar_width=60),
"[progress.completed]{task.completed}/{task.total}",
"[progress.completed]{task.fields[unit]}",
"[progress.percentage]{task.percentage:>3.0f}%",
rich.progress.TimeElapsedColumn(),
rich.progress.TimeRemainingColumn(),
console=rich.console.Console(file=self.stream),
disable=self.quiet > 0,
)
self.console = self.progress.console
def _check(self) -> None:
    """Verify that the command-line arguments were parsed successfully.

    Raises:
        CommandExit: with exit code 1 when `docopt` failed to parse the
            arguments (in which case `self.args` is a `DocoptExit`).
    """
    # NOTE: the interleaved legacy return-code implementation (help/version
    # branches) was removed; help is handled by the `Main` command instead.
    if isinstance(self.args, docopt.DocoptExit):
        self.console.print(self.args)
        raise CommandExit(1)
    return None
def _get_log_level(self) -> Optional[str]:
    """Derive a log-level name from the parsed verbosity flags."""
    # a single "-v" maps to VERBOSE, more ("-vv") to DEBUG
    if self.args.get("--verbose"):
        return "VERBOSE" if self.args.get("--verbose") == 1 else "DEBUG"
    # "--quiet" silences everything below ERROR
    elif self.args.get("--quiet"):
        return "ERROR"
def _check_flag(
    self,
    name: str,
    convert: Optional[Callable[[str], _T]] = None,
    check: Optional[Callable[[_T], bool]] = None,
    message: Optional[str] = None,
    hint: Optional[str] = None,
) -> _T:
    """Validate and convert the value of a single command-line flag.

    Arguments:
        name: the name of the flag as it appears in `self.args`.
        convert: a callable converting the raw string value, or `None`
            to use the value unchanged.
        check: a predicate the converted value must satisfy, or `None`
            to accept any converted value.
        message: unused; kept for backward compatibility with callers.
        hint: a human-readable description of the expected value, shown
            in the error message when validation fails.

    Returns:
        the converted value, when conversion and validation succeed.

    Raises:
        InvalidArgument: when conversion or validation fails; the error
            is reported on the console before raising.
    """
    _convert = (lambda x: x) if convert is None else convert
    _check = (lambda x: True) if check is None else check
    try:
        value = _convert(self.args[name])
        if not _check(value):
            raise ValueError(self.args[name])
    except Exception as err:
        if hint is None:
            self.error(f"Invalid value for argument [purple]{name}[/]:", repr(self.args[name]))
        else:
            self.error(f"Invalid value for argument [purple]{name}[/]:", repr(self.args[name]), f"(expected {hint})")
        raise InvalidArgument(self.args[name]) from err
    else:
        return value
# -- Logging methods -----------------------------------------------------
def error(self, message, *args, level=0):
    """Print an error line, unless errors were silenced with ``-qqq``."""
    hidden = self.quiet > 2 or level > self.verbose
    if not hidden:
        prefix = self._logprefix()
        self.console.print(*prefix, "[bold red]FAIL[/]", message, *args)
def info(self, verb, *args, level=1):
    """Print an informational line when verbosity allows it."""
    if self.quiet == 0 and level <= self.verbose:
        self.console.print(
            *self._logprefix(),
            "[bold blue]INFO[/]",  # FIX: was an f-string with no placeholders
            verb,
            *args,
        )
def success(self, verb, *args, level=1):
    """Print a success line when verbosity allows it."""
    if self.quiet == 0 and level <= self.verbose:
        self.console.print(
            *self._logprefix(),
            "[bold green] OK[/]",  # FIX: was an f-string with no placeholders
            verb,
            *args,
        )
def warn(self, verb, *args, level=0):
    """Print a warning line, unless warnings were silenced with ``-qq``."""
    if self.quiet <= 1 and level <= self.verbose:
        parts = [*self._logprefix(), "[bold yellow]WARN[/]", verb, *args]
        self.console.print(*parts)
def _logprefix(self):
    """Build the dimmed ``timestamp host progname[pid]`` prefix for log lines."""
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    return [
        f"[dim cyan]{timestamp}[/]",
        f"[dim purple]{self._hostname}[/]",
        f"[dim]{__progname__}[[default dim]{self._pid}[/]][/]",
    ]
def _showwarnings(
    self,
    message: str,
    category: Type[Warning],
    filename: str,
    lineno: int,
    file: Optional[TextIO] = None,
    line: Optional[str] = None,
) -> None:
    """`warnings.showwarning`-compatible callback routing warnings to `warn`.

    Only the warning message is used; the location arguments are accepted
    for signature compatibility and ignored.
    """
    # FIX: use a dedicated loop variable instead of shadowing the `line`
    # parameter of the callback signature.
    for text in filter(str.strip, str(message).splitlines()):
        self.warn(text.strip())
"""Implementation of the main ``gecco`` command.
"""
import contextlib
import signal
import sys
import textwrap
import typing
import warnings
from typing import Mapping, Optional, Type
import better_exceptions
import docopt
import operator
import pkg_resources
import rich.traceback
from ... import __version__
from .._utils import classproperty, wrap_warnings
from .._utils import in_context, patch_showwarnings
from . import __name__ as __parent__
from ._base import Command
from ._base import Command, CommandExit, InvalidArgument
class Main(Command):
"""The *main* command launched before processing subcommands.
"""
@classmethod
def _get_subcommand_names(cls) -> List[str]:
    """List the names of all registered ``gecco`` subcommand entry points."""
    return [cmd.name for cmd in pkg_resources.iter_entry_points(__parent__)]
@classmethod
def _get_subcommands(cls) -> Mapping[str, Type[Command]]:
commands = {}
......@@ -32,25 +39,34 @@ class Main(Command):
return commands
@classmethod
def _get_subcommand_by_name(cls, name: str) -> Optional[Type[Command]]:
    """Load the `Command` subclass registered under ``name``.

    Returns:
        the loaded command class, or `None` when no entry point with
        that name is registered.
    """
    # NOTE: the interleaved residue of the removed `_get_subcommand`
    # helper and of the old `doc` classproperty was dropped here.
    for entry_point in pkg_resources.iter_entry_points(__parent__):
        if entry_point.name == name:
            return entry_point.load()
    return None
# --
@classmethod
def doc(cls, fast=False): # noqa: D102
if fast:
commands = (f" {cmd}" for cmd in cls._get_subcommand_names())
else:
commands = (
" {:12}{}".format(name, typing.cast(Command, cmd).summary)
for name, cmd in sorted(
cls._get_subcommands().items(), key=operator.itemgetter(0)
)
)
)
return (
textwrap.dedent(
"""
gecco - Gene Cluster Prediction with Conditional Random Fields
Usage:
gecco [-v | -vv | -q | -l <level>] [options] (-h | --help) [<cmd>]
gecco [-v | -vv | -q | -l <level>] [options] <cmd> [<args>...]
gecco [-v | -vv | -q | -qq] <cmd> [<args>...]
gecco --version
gecco --help [<cmd>]
Commands:
{commands}
......@@ -59,15 +75,12 @@ class Main(Command):
-h, --help show the message for ``gecco`` or
for a given subcommand.
-q, --quiet silence any output other than errors
(corresponds to the log level ERROR).
-v, --verbose control the verbosity of the output
(corresponds to the log level DEBUG).
(-qq silences everything).
-v, --verbose increase verbosity (-v is minimal,
-vv is verbose, and -vvv shows
debug information).
-V, --version show the program version and exit.
Parameters - Debug:
--traceback display full traceback on error.
-l <level>, --log <level> the level of log message to display.
[available: DEBUG, INFO, WARNING, ERROR]
"""
)
.lstrip()
......@@ -76,67 +89,68 @@ class Main(Command):
_options_first = True
def execute(self, ctx: contextlib.ExitStack) -> int:
    """Run the main routine: dispatch to a subcommand or handle global flags.

    Arguments:
        ctx: an `ExitStack` collecting every context entered while running
            (warnings patch, progress bars) so they are cleaned up on exit.

    Returns:
        `int`: an exit code, with 0 on success; `-SIGINT` on interruption.
    """
    # NOTE: reconstructed from the diff residue interleaving the removed
    # `__call__` implementation with this `execute` implementation.
    # Run the app, elegantly catching any interrupts or exceptions
    try:
        # check arguments and enter context
        self._check()
        ctx.enter_context(patch_showwarnings(self._showwarnings))
        # Get the subcommand class
        subcmd_name = self.args["<cmd>"]
        try:
            subcmd_cls = self._get_subcommand_by_name(subcmd_name)
        except pkg_resources.DistributionNotFound as dnf:
            self.error("The", repr(subcmd_name), "subcommand requires package", dnf.req)
            return 1
        # exit if no known command was found
        if subcmd_name is not None and subcmd_cls is None:
            self.error("Unknown subcommand", repr(subcmd_name))
            return 1
        # if a help message was required, delegate to the `gecco help` command
        if (
            self.args["--help"]
            or "-h" in self.args["<args>"]
            or "--help" in self.args["<args>"]
        ):
            subcmd = typing.cast(Type[Command], self._get_subcommand_by_name("help"))(
                argv=["help"] + [subcmd_name],
                stream=self._stream,
                options=self.args,
                config=self.config,
            )
        # print version information if `--version` in flags
        elif self.args["--version"]:
            self.console.print("gecco", __version__)
            return 0
        # initialize the command if is valid
        else:
            subcmd = typing.cast(Type[Command], subcmd_cls)(
                argv=[self.args["<cmd>"]] + self.args["<args>"],
                stream=self._stream,
                options=self.args,
                config=self.config,
            )
        # propagate the verbosity settings of the main command
        subcmd.verbose = self.verbose
        subcmd.quiet = self.quiet
        # run the subcommand
        return subcmd.execute(ctx)
    except CommandExit as sysexit:
        return sysexit.code
    except KeyboardInterrupt:
        self.error("Interrupted")
        return -signal.SIGINT
    except Exception as e:
        self.error(
            "An unexpected error occurred. Consider opening"
            " a new issue on the bug tracker"
            " (https://github.com/zellerlab/GECCO/issues/new) if"
            " it persists, including the traceback below:"
        )
        traceback = rich.traceback.Traceback.from_exception(type(e), e, e.__traceback__)
        self.console.print(traceback)
        # return errno if exception has any
        return typing.cast(int, getattr(e, "errno", 1))
"""Implementation of the ``gecco annotate`` subcommand.
"""
import contextlib
import errno
import glob
import itertools
import logging
import multiprocessing
import operator
import os
import pickle
import tempfile
import typing
import signal
from typing import Any, Dict, Union, Optional, List, TextIO, Mapping
import numpy
import rich.emoji
import rich.progress
from Bio import SeqIO
from ._base import Command, CommandExit, InvalidArgument
from .._utils import guess_sequences_format, in_context, patch_showwarnings
from ...crf import ClusterCRF
from ...hmmer import PyHMMER, HMM, embedded_hmms
from ...model import FeatureTable, ClusterTable, ProductType
from ...orf import PyrodigalFinder
from ...types import TypeClassifier
from ...refine import ClusterRefiner
class Annotate(Command):
    """Implementation of the ``gecco annotate`` subcommand."""

    summary = "annotate protein features of one or several contigs."

    @classmethod
    def doc(cls, fast=False):  # noqa: D102
        # `fast` is accepted for signature compatibility with `Command.doc`;
        # this help text is cheap to build so it is ignored here.
        return f"""
        gecco annotate - {cls.summary}

        Usage:
            gecco annotate --genome <file> [--hmm <hmm>]... [options]

        Arguments:
            -g <file>, --genome <file>    a genomic file containing one or more
                                          sequences to use as input. Must be in
                                          one of the sequences format supported
                                          by Biopython.

        Parameters:
            -f <fmt>, --format <fmt>      the format of the input file, as a
                                          Biopython format string. GECCO is able
                                          to recognize FASTA and GenBank files
                                          automatically if this is not given.
            -o <out>, --output <out>      the file where to write the feature
                                          table. [default: features.tsv]
            -j <jobs>, --jobs <jobs>      the number of CPUs to use for
                                          multithreading. Use 0 to use all of the
                                          available CPUs. [default: 0]

        Parameters - Domain Annotation:
            -e <e>, --e-filter <e>        the e-value cutoff for protein domains
                                          to be included. [default: 1e-5]

        Parameters - Debug:
            --hmm <hmm>                   the path to one or more alternative
                                          HMM file to use (in HMMER format).
        """

    def _check(self) -> None:
        """Validate the command-line flags and store their converted values.

        Raises:
            CommandExit: with code 1 when any flag has an invalid value.
        """
        # FIX: annotated `-> None`; the previous `Optional[int]` annotation
        # was stale, as this method raises instead of returning a code.
        super()._check()
        try:
            self.e_filter = self._check_flag("--e-filter", float, lambda x: 0 <= x <= 1, hint="real number between 0 and 1")
            self.jobs = self._check_flag("--jobs", int, lambda x: x >= 0, hint="positive or null integer")
            self.format = self._check_flag("--format")
            self.genome = self._check_flag("--genome")
            self.hmm = self._check_flag("--hmm")
            self.output = self._check_flag("--output")
        except InvalidArgument:
            raise CommandExit(1)

    def _custom_hmms(self):
        """Yield an `HMM` object for every file given with the `--hmm` flag."""
        for path in self.hmm:
            base = os.path.basename(path)
            # strip a possible compression extension first, then the real one
            if base.endswith(".gz"):
                base, _ = os.path.splitext(base)
            base, _ = os.path.splitext(base)
            yield HMM(
                id=base,
                version="?",
                url="?",
                path=path,
                relabel_with=r"s/([^\.]*)(\..*)?/\1/"
            )

    # ---

    def _load_sequences(self):
        """Load the input sequences from the ``--genome`` file.

        Raises:
            CommandExit: when the file is missing or cannot be parsed.
        """
        if self.format is not None:
            format = self.format
            self.info("Using", "user-provided sequence format", repr(format), level=2)
        else:
            self.info("Detecting", "sequence format from file contents", level=2)
            format = guess_sequences_format(self.genome)
            self.success("Detected", "format of input as", repr(format), level=2)
        self.info("Loading", "sequences from genomic file", repr(self.genome), level=1)
        try:
            sequences = list(SeqIO.parse(self.genome, format))
        except FileNotFoundError as err:
            self.error("Could not find input file:", repr(self.genome))
            # FIX: was `e.errno`, which raised a NameError inside this handler
            raise CommandExit(err.errno) from err
        except Exception as err:
            self.error("Failed to load sequences:", err)
            raise CommandExit(getattr(err, "errno", 1)) from err
        else:
            self.success("Found", len(sequences), "sequences", level=1)
            return sequences

    def _extract_genes(self, sequences):
        """Run the ORF finder on every contig and collect the genes found."""
        self.info("Extracting", "genes from input sequences", level=1)
        orf_finder = PyrodigalFinder(metagenome=True, cpus=self.jobs)
        unit = "contigs" if len(sequences) > 1 else "contig"
        task = self.progress.add_task(description="ORFs finding", total=len(sequences), unit=unit)
        def callback(record, found, total):
            # advance the progress bar once per processed record
            self.success("Found", found, "genes in record", repr(record.id), level=2)
            self.progress.update(task, advance=1)
        return list(orf_finder.find_genes(sequences, progress=callback))

    def _annotate_domains(self, genes):
        """Annotate genes with HMMER, filter by e-value, and sort the results."""
        self.info("Running", "HMMER domain annotation", level=1)
        # Run all HMMs over ORFs to annotate with protein domains
        hmms = list(self._custom_hmms() if self.hmm else embedded_hmms())
        task = self.progress.add_task(description="HMM annotation", unit="HMMs", total=len(hmms))
        for hmm in self.progress.track(hmms, task_id=task):
            # FIX: use a distinct name for the per-HMM task instead of
            # rebinding `task`, which shadowed the overall progress task id
            hmm_task = self.progress.add_task(description=f"{hmm.id} v{hmm.version}", total=1, unit="domains")
            callback = lambda n, total: self.progress.update(hmm_task, advance=1, total=total)
            self.info("Starting", f"annotation with [bold blue]{hmm.id} v{hmm.version}[/]", level=2)
            features = PyHMMER(hmm, self.jobs).run(genes, progress=callback)
            self.success("Finished", f"annotation with [bold blue]{hmm.id} v{hmm.version}[/]", level=2)
            self.progress.update(task_id=hmm_task, visible=False)
        # Count number of annotated domains
        count = sum(1 for gene in genes for domain in gene.protein.domains)
        self.success("Found", count, "domains across all proteins", level=1)
        # Filter i-evalue
        self.info("Filtering", "results with e-value under", self.e_filter, level=1)
        key = lambda d: d.i_evalue < self.e_filter
        for gene in genes:
            gene.protein.domains = list(filter(key, gene.protein.domains))
        count = sum(1 for gene in genes for domain in gene.protein.domains)
        self.info("Using", "remaining", count, "domains", level=2)
        # Sort genes
        self.info("Sorting", "genes by coordinates", level=2)
        genes.sort(key=lambda g: (g.source.id, g.start, g.end))
        for gene in genes:
            gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
        return genes

    def _write_feature_table(self, genes):
        """Write the feature table for the annotated genes to ``--output``."""
        self.info("Writing", "feature table to", repr(self.output), level=1)
        with open(self.output, "w") as f:
            FeatureTable.from_genes(genes).dump(f)

    # ---

    def execute(self, ctx: contextlib.ExitStack) -> int:  # noqa: D102
        try:
            # check the CLI arguments were fine and enter context
            self._check()
            ctx.enter_context(self.progress)
            ctx.enter_context(patch_showwarnings(self._showwarnings))
            # load sequences and extract genes
            sequences = self._load_sequences()
            genes = self._extract_genes(sequences)
            if genes:
                self.success("Found", "a total of", len(genes), "genes", level=1)
            else:
                self.warn("No genes were found")
                return 0
            # annotate domains and write results
            genes = self._annotate_domains(genes)
            self._write_feature_table(genes)
            ndoms = sum(1 for gene in genes for domain in gene.protein.domains)
            # report number of proteins found
            if ndoms:
                self.success("Found", ndoms, "protein domains", level=0)
            else:
                self.warn("No protein domains were found")
        except CommandExit as cexit:
            return cexit.code
        except KeyboardInterrupt:
            self.error("Interrupted")
            return -signal.SIGINT
        else:
            return 0
"""Implementation of the ``gecco cv`` subcommand.
"""
import contextlib
import copy
import functools
import itertools
......@@ -9,12 +10,13 @@ import operator
import multiprocessing
import random
import typing
from typing import List
from typing import Any, Dict, Union, Optional, List, TextIO, Mapping
import tqdm
import rich.progress
import sklearn.model_selection
from ._base import Command
from ._base import Command, CommandExit, InvalidArgument
from .._utils import in_context, patch_showwarnings
from ...model import ClusterTable, FeatureTable, ProductType
from ...crf import ClusterCRF
from ...crf.cv import LeaveOneGroupOut
......@@ -23,148 +25,132 @@ from ...crf.cv import LeaveOneGroupOut
class Cv(Command): # noqa: D101
summary = "perform cross validation on a training set."
doc = f"""
gecco cv - {summary}
Usage:
gecco cv (-h | --help)
gecco cv kfold -f <table> [-c <data>] [options]
gecco cv loto -f <table> -c <data> [options]
Arguments:
-f <data>, --features <table> a domain annotation table, used to
labeled as BGCs and non-BGCs.
-c <data>, --clusters <table> a cluster annotation table, use to
stratify clusters by type in LOTO
mode.
Parameters:
-o <out>, --output <out> the name of the output cross-validation
table. [default: cv.tsv]
-j <jobs>, --jobs <jobs> the number of CPUs to use for
multithreading. Use 0 to use all of the
available CPUs. [default: 0]
Parameters - Domain Annotation:
-e <e>, --e-filter <e> the e-value cutoff for domains to
be included [default: 1e-5]
Parameters - Training:
--c1 <C1> parameter for L1 regularisation.
[default: 0.15]
--c2 <C2> parameter for L2 regularisation.
[default: 0.15]
--feature-type <type> how features should be extracted
(single, overlap, or group).
[default: group]
--overlap <N> how much overlap to consider if
features overlap. [default: 2]
--splits <N> number of folds for cross-validation
(if running `kfold`). [default: 10]
--select <N> fraction of most significant features
to select from the training data.
--shuffle enable shuffling of stratified rows.
"""
@classmethod
def doc(cls, fast=False):  # noqa: D102
    # `fast` is accepted for signature compatibility with `Command.doc`;
    # this help text is cheap to build so it is ignored here.
    return f"""
    gecco cv - {cls.summary}

    Usage:
        gecco cv kfold -f <table> [-c <data>] [options]
        gecco cv loto -f <table> -c <data> [options]

    Arguments:
        -f <data>, --features <table> a domain annotation table, used to
                                      labeled as BGCs and non-BGCs.
        -c <data>, --clusters <table> a cluster annotation table, used
                                      to stratify clusters by type in
                                      LOTO mode.

    Parameters:
        -o <out>, --output <out>      the name of the output file where
                                      the cross-validation table will be
                                      written. [default: cv.tsv]
        -j <jobs>, --jobs <jobs>      the number of CPUs to use for
                                      multithreading. Use 0 to use all
                                      the available CPUs. [default: 0]

    Parameters - Domain Annotation:
        -e <e>, --e-filter <e>        the e-value cutoff for domains to
                                      be included [default: 1e-5]

    Parameters - Training:
        --c1 <C1>                     parameter for L1 regularisation.
                                      [default: 0.15]
        --c2 <C2>                     parameter for L2 regularisation.
                                      [default: 0.15]
        --feature-type <type>         how features should be extracted
                                      (single, overlap, or group).
                                      [default: group]
        --overlap <N>                 how much overlap to consider if
                                      features overlap. [default: 2]
        --splits <N>                  number of folds for cross-validation
                                      (if running `kfold`). [default: 10]
        --select <N>                  fraction of most significant features
                                      to select from the training data.
        --shuffle                     enable shuffling of stratified rows.
    """
def _check(self) -> None:
    """Validate command-line flags and store their converted values.

    Raises:
        CommandExit: with code 1 when any flag has an invalid value.
    """
    # NOTE: reconstructed from diff residue interleaving the removed
    # `__call__`/`_load_sequences` implementations with this method.
    super()._check()
    try:
        self.feature_type = self._check_flag(
            "--feature-type",
            str,
            lambda x: x in {"single", "overlap", "group"},
            hint="'single', 'overlap' or 'group'"
        )
        self.overlap = self._check_flag(
            "--overlap",
            int,
            lambda x: x > 0,
            hint="positive integer",
        )
        self.c1 = self._check_flag("--c1", float, hint="real number")
        self.c2 = self._check_flag("--c2", float, hint="real number")
        self.splits = self._check_flag("--splits", int, lambda x: x > 1, hint="integer greater than 1")
        self.e_filter = self._check_flag(
            "--e-filter",
            float,
            lambda x: 0 <= x <= 1,
            hint="real number between 0 and 1"
        )
        self.select = self._check_flag(
            "--select",
            lambda x: x if x is None else float(x),
            lambda x: x is None or 0 <= x <= 1,
            hint="real number between 0 and 1"
        )
        self.jobs = self._check_flag(
            "--jobs",
            int,
            lambda x: x >= 0,
            hint="positive or null integer"
        )
        self.features = self._check_flag("--features")
        # FIX: `--clusters` is read later by `_loto_splits` as
        # `self.clusters` but was never assigned here.
        self.clusters = self._check_flag("--clusters")
        self.loto = self.args["loto"]
        self.output = self.args["--output"]
    except InvalidArgument:
        raise CommandExit(1)
# --
def _load_features(self):
    """Load the domain annotation table given with ``--features``."""
    self.info("Loading", "features table from file", repr(self.features))
    with open(self.features) as handle:
        table = FeatureTable.load(handle)
    return table
def _convert_to_genes(self, features):
    """Turn a feature table into a list of genes sorted by coordinates."""
    self.info("Converting", "features to genes")
    total = len(set(features.protein_id))
    label = "gene" if total == 1 else "genes"
    task = self.progress.add_task("Feature conversion", total=total, unit=label)
    tracked = self.progress.track(features.to_genes(), total=total, task_id=task)
    genes = list(tracked)
    self.info("Sorting", "genes by genomic coordinates")
    genes.sort(key=operator.attrgetter("source.id", "start", "end"))
    self.info("Sorting", "domains by protein coordinates")
    domain_key = operator.attrgetter("start", "end")
    for gene in genes:
        gene.protein.domains.sort(key=domain_key)
    return genes
self.logger.info("Grouping genes by source sequence")
def _group_genes(self, genes):
    """Group genes by source sequence, each group sorted by start coordinate.

    If the ``--shuffle`` flag was given, the resulting list of sequences
    is shuffled in place before being returned.
    """
    self.info("Grouping", "genes by source sequence")
    # `genes` was sorted by `source.id` in `_convert_to_genes`, which
    # `itertools.groupby` requires to form one group per sequence
    groups = itertools.groupby(genes, key=operator.attrgetter("source.id"))
    seqs = [sorted(group, key=operator.attrgetter("start")) for _, group in groups]
    if self.args["--shuffle"]:
        # fix: a merged diff left a stale duplicate `self.logger.info(...)`
        # call here; `Command` no longer has a `logger` attribute since the
        # migration to `rich`, so only the `self.info(...)` call is kept
        self.info("Shuffling training data sequences")
        random.shuffle(seqs)
    return seqs
def _loto_splits(self, seqs):
self.logger.info("Loading the clusters table")
with open(self.args["--clusters"]) as in_:
with open(self.clusters) as in_:
table = ClusterTable.load(in_)
index = { row.sequence_id: row.type for row in table }
if len(index) != len(table):
......@@ -182,5 +168,69 @@ class Cv(Command): # noqa: D101
return list(LeaveOneGroupOut().split(seqs, groups=groups))
def _kfold_splits(self, seqs):
    """Generate k-fold train/test index splits over the sequences.

    The number of folds comes from ``--splits``, already validated and
    converted to an `int` by `_check` and stored as ``self.splits``.
    """
    # fix: a merged diff left a stale `k = self.args["--splits"]` (an
    # unconverted docopt string) and an early return shadowing this one;
    # only the validated `self.splits` version is kept
    return list(sklearn.model_selection.KFold(self.splits).split(seqs))
def _get_train_data(self, train_indices, seqs):
# extract train data
return [gene for i in train_indices for gene in seqs[i]]
def _get_test_data(self, test_indices, seqs):
# make a clean copy of the test data without gene probabilities
test_data = [copy.deepcopy(gene) for i in test_indices for gene in seqs[i]]
for gene in test_data:
gene.protein.domains = [d.with_probability(None) for d in gene.protein.domains]
return test_data
def _fit_predict(self, train_data, test_data):
    """Train a CRF on the training fold and annotate the test fold.

    A fresh `ClusterCRF` is built for every fold with the hyperparameters
    collected from the command line in `_check`.
    """
    model = ClusterCRF(
        self.feature_type,
        algorithm="lbfgs",
        overlap=self.overlap,
        c1=self.c1,
        c2=self.c2,
    )
    model.fit(train_data, cpus=self.jobs, select=self.select)
    return model.predict_probabilities(test_data, cpus=self.jobs)
def _write_fold(self, fold, genes, append=False):
    """Write one fold of predictions to the output feature table.

    When ``append`` is false the output file is (re)created with a header;
    otherwise the rows are appended without a header.
    """
    table = FeatureTable.from_genes(genes).to_dataframe()
    mode = "a" if append else "w"
    with open(self.output, mode) as handle:
        table.assign(fold=fold).to_csv(handle, header=not append, sep="\t", index=False)
# --
def execute(self, ctx: contextlib.ExitStack) -> int:  # noqa: D102
    try:
        # check arguments and enter context
        self._check()
        ctx.enter_context(self.progress)
        ctx.enter_context(patch_showwarnings(self._showwarnings))
        # load features and convert them to genes grouped by sequence
        features = self._load_features()
        self.success("Loaded", len(features), "feature annotations")
        # fix: typo "Recoverd" in the success message
        genes = self._convert_to_genes(features)
        self.success("Recovered", len(genes), "genes from the feature annotations")
        seqs = self._group_genes(genes)
        self.success("Grouped", "genes into", len(seqs), "sequences")
        # build the cross-validation folds (leave-one-type-out or k-fold)
        if self.loto:
            splits = self._loto_splits(seqs)
        else:
            splits = self._kfold_splits(seqs)
        # run the cross-validation, writing predictions fold by fold
        unit = "fold" if len(splits) == 1 else "folds"
        task = self.progress.add_task(description="Cross-Validation", total=len(splits), unit=unit)
        self.info("Performing cross-validation")
        for i, (train_indices, test_indices) in enumerate(self.progress.track(splits, task_id=task)):
            train_data = self._get_train_data(train_indices, seqs)
            test_data = self._get_test_data(test_indices, seqs)
            new_genes = self._fit_predict(train_data, test_data)
            # fix: write the *predicted* genes (`new_genes`), not the raw
            # input `genes`, and only append after the first fold so the
            # header is written once and no fold truncates the file
            # (the original `append=i==0` was inverted, keeping only the
            # last fold in the output)
            self._write_fold(i + 1, new_genes, append=i > 0)
    except CommandExit as cexit:
        return cexit.code
    except KeyboardInterrupt:
        self.error("Interrupted")
        return -signal.SIGINT
    else:
        return 0
"""Implementation of the ``gecco embed`` subcommand.
"""
import contextlib
import csv
import itertools
import logging
......@@ -9,167 +10,223 @@ import multiprocessing.pool
import os
import pickle
import random
import signal
import typing
import warnings
import numpy
import pandas
import tqdm
from ._base import Command
from .._utils import numpy_error_context
from ._base import Command, InvalidArgument, CommandExit
from .._utils import numpy_error_context, in_context, patch_showwarnings
from ...hmmer import HMMER
class Embed(Command): # noqa: D101
summary = "embed BGC annotations into non-BGC contigs for training."
@classmethod
def doc(cls, fast=False):  # noqa: D102
    # fix: a merged diff left both the legacy `doc = f"""..."""` class
    # attribute and this classmethod; only the classmethod is kept.
    # NOTE(review): the docopt indentation below was reconstructed —
    # confirm the usage/arguments layout against the repository version.
    return f"""
    gecco embed - {cls.summary}

    Usage:
        gecco embed [--bgc <data>]... [--no-bgc <data>]... [options]

    Arguments:
        --bgc <data>                  the path to the annotation table
                                      containing BGC-only training instances.
        --no-bgc <data>               the path to the annotation table
                                      containing non-BGC training instances.

    Parameters:
        -o <out>, --output <out>      the prefix used for the output files
                                      (which will be ``<prefix>.features.tsv``
                                      and ``<prefix>.clusters.tsv``).
                                      [default: embedding]
        --min-size <N>                the minimum size for padding sequences.
                                      [default: 500]
        -e <e>, --e-filter <e>        the e-value cutoff for domains to be
                                      included. [default: 1e-5]
        -j <jobs>, --jobs <jobs>      the number of CPUs to use for
                                      multithreading. Use 0 to use all of the
                                      available CPUs. [default: 0]
        --skip <N>                    skip the first N contigs while creating
                                      the embedding. [default: 0]
    """
def _check(self) -> typing.Optional[int]:
    """Validate command-line flags and store them as attributes.

    Raises:
        CommandExit: (with code 1) if any argument has an invalid value.
    """
    # fix: a merged diff interleaved the old `_check` (returning error
    # codes and using the removed `self.logger`) and the old `__call__`
    # body with the new implementation; only the new-style validation
    # based on `_check_flag` / `InvalidArgument` is kept
    super()._check()
    try:
        self.skip = self._check_flag("--skip", int, lambda x: x >= 0, hint="positive or null integer")
        self.min_size = self._check_flag("--min-size", int, lambda x: x >= 0, hint="positive or null integer")
        self.e_filter = self._check_flag("--e-filter", float, lambda x: 0 <= x <= 1, hint="real number between 0 and 1")
        self.jobs = self._check_flag("--jobs", int, lambda x: x >= 0, hint="positive or null integer")
        self.bgc = self.args["--bgc"]
        self.no_bgc = self.args["--no-bgc"]
        self.output = self.args["--output"]
    except InvalidArgument:
        raise CommandExit(1)
# ---
def _read_table(self, path: str) -> "pandas.DataFrame":
self.info("Reading", "table from", repr(path), level=2)
return pandas.read_table(path, dtype={"domain": str})
def _read_no_bgc(self):
self.info("Reading", "non-BGC features")
# Read the non-BGC table and assign the Y column to `0`
_jobs = os.cpu_count() if not self.jobs else self.jobs
with multiprocessing.pool.ThreadPool(_jobs) as pool:
rows = pool.map(self._read_table, self.no_bgc)
no_bgc_df = pandas.concat(rows).assign(bgc_probability=0.0)