@@ -24,7 +24,7 @@ jobs:
run: python -m pip install -r requirements.txt
- name: Build source distribution
run: python setup.py sdist
- name: Store built wheels
- name: Store built archive
uses: actions/upload-artifact@v2
with:
name: dist
@@ -32,7 +32,7 @@ jobs:
wheel:
runs-on: ubuntu-latest
name: Build source distribution
name: Build wheel distribution
steps:
- name: Checkout code
uses: actions/checkout@v2
@@ -48,7 +48,7 @@ jobs:
run: python -m pip install -r requirements.txt
- name: Build wheel distribution
run: python setup.py bdist_wheel
- name: Store built wheels
- name: Store built wheel
uses: actions/upload-artifact@v2
with:
name: dist
......
@@ -27,17 +27,13 @@ jobs:
run: python setup.py build_data -f -r
- name: Compress Pfam HMM
run: gzip -c build/lib/gecco/hmmer/Pfam.h3m > Pfam.h3m.gz
- name: Compress Tigrfam HMM
run: gzip -c build/lib/gecco/hmmer/Tigrfam.h3m > Tigrfam.h3m.gz
- name: Upload HMM
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
files: |
Pfam.h3m.gz
Tigrfam.h3m.gz
files: Pfam.h3m.gz
chandler:
environment: GitHub Releases
......
@@ -43,7 +43,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Update CI dependencies
run: python -m pip install -U pip coverage wheel numpy
run: python -m pip install -U pip coverage wheel numpy setuptools
- name: List project dependencies
run: python setup.py list_requirements
- name: Install project dependencies
@@ -99,7 +99,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
- name: Update CI dependencies
run: python -m pip install -U pip coverage wheel
run: python -m pip install -U pip coverage wheel setuptools
- name: List project dependencies
run: python setup.py list_requirements
- name: Install project dependencies
......
@@ -5,7 +5,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.7.0...master
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.8.0...master
## [v0.8.0] - 2021-07-03
[v0.8.0]: https://git.embl.de/grp-zeller/GECCO/compare/v0.7.0...v0.8.0
### Changed
- Retrain internal model using new sequence embeddings and remove broken/duplicate BGCs from MIBiG 2.0.
- Bump minimum `pyhmmer` version to `v0.4.0` to improve exception handling.
- Bump minimum `pyrodigal` version to `v0.5.0` to fix sequence decoding on some platforms.
- Use p-values instead of e-values to filter domains obtained with HMMER.
- `gecco cv` and `gecco train` now seed the RNG with a user-defined seed before shuffling rows of training data.
### Fixed
- Extraction of BGC compositions for the type predictor while training.
- `ClusterCRF.trained` failing to open an external model.
### Added
- `Domain.pvalue` attribute to access the p-value of a domain annotation.
- Mandatory `pvalue` column to `FeatureTable` objects.
- Support for loading several feature tables in `gecco train` and `gecco cv`.
- Warnings to `ClusterCRF.fit` when selecting uninformative features.
- `--correction` flag to `gecco train` and `gecco cv` to select the multiple testing correction method used when computing p-values with the Fisher Exact Tests.
### Removed
- Outdated `gecco embed` command.
- Unused `--truncate` flag from the `gecco train` CLI.
- Tigrfam domains, which did not improve performance on the new training data.
## [v0.7.0] - 2021-05-31
[v0.7.0]: https://git.embl.de/grp-zeller/GECCO/compare/v0.6.3...v0.7.0
......
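Two of the entries above, p-value domain filtering and seeded shuffling, reduce to a few lines; a minimal sketch (object shapes assumed to mirror the `Domain.pvalue` attribute and the `_filter_domains`/`_seed_rng` helpers that appear later in this diff):

import random

def filter_domains(genes, p_filter=1e-9):
    # keep only the domains whose annotation p-value is below the cutoff
    for gene in genes:
        gene.protein.domains = [d for d in gene.protein.domains if d.pvalue < p_filter]
    return genes

def shuffle_sequences(seqs, seed=42):
    # seeding the RNG first makes `gecco train` / `gecco cv` shuffles reproducible
    random.seed(seed)
    random.shuffle(seqs)
    return seqs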
@@ -123,7 +123,7 @@ html_theme_options = {
"navbar_pagenav": False,
# A list of tuples containing pages or urls to link to.
"navbar_links": [
("Repository", _parser.get("metadata", "home-page").strip(), True),
("Repository", project_urls["Repository"].strip(), True),
],
# + [
# (k, v, True)
......
@@ -10,4 +10,4 @@ See Also:
__author__ = "Martin Larralde"
__license__ = "GPLv3"
__version__ = "0.7.0"
__version__ = "0.8.0"
@@ -3,6 +3,7 @@
import contextlib
import functools
import io
import logging
import typing
import warnings
@@ -16,6 +17,65 @@ if typing.TYPE_CHECKING:
None
]
class ProgressReader(io.RawIOBase):
"""A reader that updates a progress bar while it's being read from.
"""
@staticmethod
def scale_size(length):
for scale, unit in enumerate(["B", "kiB", "MiB", "GiB", "TiB"]):
if length > 1024:
length /= 1024
else:
break
return length, scale, unit
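# e.g. scale_size(2_621_440) divides twice (B -> kiB -> MiB) and returns
# (2.5, 2, "MiB"); `_update` below rescales raw byte counts with the same
# `1024 ** scale` factor.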
def __init__(self, handle, progress, task, scale=0):
self.handle = handle
self.progress = progress
self.task = task
self.scale = scale
def __enter__(self):
self.handle.__enter__()
return self
def __exit__(self, exc_ty, exc_val, tb):
    self.handle.__exit__(exc_ty, exc_val, tb)
return False
def _update(self, length):
self.progress.update(self.task, advance=length / (1024 ** self.scale))
def readable(self):
return True
def seekable(self):
return False
def writable(self):
return False
def readline(self, size=-1):
line = self.handle.readline(size)
self._update(len(line))
return line
def readlines(self, hint=-1):
lines = self.handle.readlines(hint)
self._update(sum(map(len, lines)))
return lines
def read(self, size=-1):
block = self.handle.read(size)
self._update(len(block))
return block
def close(self):
self.handle.close()
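# Intended wiring (mirroring `Annotate._load_sequences` further down in this
# diff): wrap a binary handle so that every read advances a rich progress task.
#
#     total, scale, unit = ProgressReader.scale_size(os.stat(path).st_size)
#     task = progress.add_task("Loading", total=total, unit=unit, precision=".1f")
#     with ProgressReader(open(path, "rb"), progress, task, scale) as f:
#         data = f.read()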
@contextlib.contextmanager
def patch_showwarnings(new_showwarning: "ShowWarning") -> Iterator[None]:
"""Make a context patching `warnings.showwarning` with the given function.
......
@@ -99,7 +99,7 @@ class Command(metaclass=abc.ABCMeta):
rich.progress.SpinnerColumn(finished_text="[green]:heavy_check_mark:[/]"),
"[progress.description]{task.description}",
rich.progress.BarColumn(bar_width=60),
"[progress.completed]{task.completed}/{task.total}",
"[progress.completed]{task.completed:{task.fields[precision]}}/{task.total:{task.fields[precision]}}",
"[progress.completed]{task.fields[unit]}",
"[progress.percentage]{task.percentage:>3.0f}%",
rich.progress.TimeElapsedColumn(),
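# `task.fields[precision]` is a plain format spec: ".1f" renders the counters
# as e.g. "2.5/12.0" (useful for fractional MiB totals), while "" keeps the
# numbers' default formatting for discrete units such as contigs or HMMs.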
@@ -124,9 +124,12 @@ class Command(metaclass=abc.ABCMeta):
check: Optional[Callable[[_T], bool]] = None,
message: Optional[str] = None,
hint: Optional[str] = None,
optional: bool = False,
) -> Optional[_T]:
_convert = (lambda x: x) if convert is None else convert
_check = (lambda x: True) if check is None else check
if optional and self.args[name] is None:
return None
try:
value = _convert(self.args[name])
if not _check(value):
......
@@ -4,7 +4,9 @@
import contextlib
import errno
import glob
import gzip
import itertools
import io
import logging
import multiprocessing
import operator
@@ -16,7 +18,12 @@ import signal
from typing import Any, Dict, Union, Optional, List, TextIO, Mapping
from ._base import Command, CommandExit, InvalidArgument
from .._utils import guess_sequences_format, in_context, patch_showwarnings
from .._utils import (
guess_sequences_format,
in_context,
patch_showwarnings,
ProgressReader,
)
class Annotate(Command): # noqa: D101
@@ -50,7 +57,9 @@ class Annotate(Command): # noqa: D101
Parameters - Domain Annotation:
-e <e>, --e-filter <e> the e-value cutoff for protein domains
to be included. [default: 1e-5]
to be included.
-p <p>, --p-filter <p> the p-value cutoff for protein domains
to be included. [default: 1e-9]
Parameters - Debug:
--hmm <hmm> the path to one or more alternative
@@ -60,7 +69,20 @@ class Annotate(Command): # noqa: D101
def _check(self) -> typing.Optional[int]:
super()._check()
try:
self.e_filter = self._check_flag("--e-filter", float, lambda x: 0 <= x <= 1, hint="real number between 0 and 1")
self.e_filter = self._check_flag(
"--e-filter",
float,
lambda x: x > 0,
hint="real number above 0",
optional=True,
)
self.p_filter = self._check_flag(
"--p-filter",
float,
lambda x: x > 0,
hint="real number above 0",
optional=True,
)
self.jobs = self._check_flag("--jobs", int, lambda x: x >= 0, hint="positive or null integer")
self.format = self._check_flag("--format")
self.genome = self._check_flag("--genome")
@@ -75,11 +97,16 @@ class Annotate(Command): # noqa: D101
for path in self.hmm:
base = os.path.basename(path)
file = open(path, "rb")
if base.endswith(".gz"):
base, _ = os.path.splitext(base)
file = gzip.GzipFile(fileobj=file)
base, _ = os.path.splitext(base)
with pyhmmer.plan7.HMMFile(path) as hmm_file:
size = sum(1 for _ in hmm_file)
self.info("Counting", "profiles in HMM file", repr(path), level=1)
with file:
with pyhmmer.plan7.HMMFile(file) as hmm_file:
size = sum(1 for _ in hmm_file)
self.success("Found", size, "profiles in HMM file", repr(path), level=1)
yield HMM(
id=base,
version="?",
@@ -94,6 +121,7 @@ class Annotate(Command): # noqa: D101
def _load_sequences(self):
from Bio import SeqIO
# guess format or use the one given in CLI
if self.format is not None:
format = self.format
self.info("Using", "user-provided sequence format", repr(format), level=2)
@@ -102,13 +130,20 @@ class Annotate(Command): # noqa: D101
format = guess_sequences_format(self.genome)
self.success("Detected", "format of input as", repr(format), level=2)
# get filesize and unit
input_size = os.stat(self.genome).st_size
total, scale, unit = ProgressReader.scale_size(input_size)
task = self.progress.add_task("Loading sequences", total=total, unit=unit, precision=".1f")
# load sequences
self.info("Loading", "sequences from genomic file", repr(self.genome), level=1)
try:
sequences = list(SeqIO.parse(self.genome, format))
with ProgressReader(open(self.genome, "rb"), self.progress, task, scale) as f:
sequences = list(SeqIO.parse(io.TextIOWrapper(f), format))
except FileNotFoundError as err:
self.error("Could not find input file:", repr(self.genome))
raise CommandExit(err.errno) from err
except Exception as err:
except ValueError as err:
self.error("Failed to load sequences:", err)
raise CommandExit(getattr(err, "errno", 1)) from err
else:
@@ -122,7 +157,7 @@ class Annotate(Command): # noqa: D101
orf_finder = PyrodigalFinder(metagenome=True, cpus=self.jobs)
unit = "contigs" if len(sequences) > 1 else "contig"
task = self.progress.add_task(description="ORFs finding", total=len(sequences), unit=unit)
task = self.progress.add_task(description="ORFs finding", total=len(sequences), unit=unit, precision="")
def callback(record, found, total):
self.success("Found", found, "genes in record", repr(record.id), level=2)
@@ -137,9 +172,9 @@ class Annotate(Command): # noqa: D101
# Run all HMMs over ORFs to annotate with protein domains
hmms = list(self._custom_hmms() if self.hmm else embedded_hmms())
task = self.progress.add_task(description=f"HMM annotation", unit="HMMs", total=len(hmms))
task = self.progress.add_task(description=f"HMM annotation", unit="HMMs", total=len(hmms), precision="")
for hmm in self.progress.track(hmms, task_id=task, total=len(hmms)):
task = self.progress.add_task(description=f"{hmm.id} v{hmm.version}", total=hmm.size, unit="domains")
task = self.progress.add_task(description=f"{hmm.id} v{hmm.version}", total=hmm.size, unit="domains", precision="")
callback = lambda h, t: self.progress.update(task, advance=1)
self.info("Starting", f"annotation with [bold blue]{hmm.id} v{hmm.version}[/]", level=2)
features = PyHMMER(hmm, self.jobs).run(genes, progress=callback)
@@ -150,14 +185,8 @@ class Annotate(Command): # noqa: D101
count = sum(1 for gene in genes for domain in gene.protein.domains)
self.success("Found", count, "domains across all proteins", level=1)
# Filter i-evalue
self.info("Filtering", "results with e-value under", self.e_filter, level=1)
key = lambda d: d.i_evalue < self.e_filter
for gene in genes:
gene.protein.domains = list(filter(key, gene.protein.domains))
count = sum(1 for gene in genes for domain in gene.protein.domains)
self.info("Using", "remaining", count, "domains", level=2)
# Filter i-evalue and p-value if required
genes = self._filter_domains(genes)
# Sort genes
self.info("Sorting", "genes by coordinates", level=2)
@@ -167,6 +196,23 @@ class Annotate(Command): # noqa: D101
return genes
def _filter_domains(self, genes):
# Filter i-evalue and p-value if required
if self.e_filter is not None:
self.info("Filtering", "domains with e-value under", self.e_filter, level=1)
key = lambda d: d.i_evalue < self.e_filter
for gene in genes:
gene.protein.domains = list(filter(key, gene.protein.domains))
if self.p_filter is not None:
self.info("Filtering", "domains with p-value under", self.p_filter, level=1)
key = lambda d: d.pvalue < self.p_filter
for gene in genes:
gene.protein.domains = list(filter(key, gene.protein.domains))
if self.p_filter is not None or self.e_filter is not None:
count = sum(1 for gene in genes for domain in gene.protein.domains)
self.info("Using", "remaining", count, "domains", level=1)
return genes
def _write_feature_table(self, genes):
from ...model import FeatureTable
......
@@ -85,11 +85,11 @@ class Convert(Command): # noqa: D101
# collect `*.clusters.tsv` files
cluster_files = glob.glob(os.path.join(self.input_dir, "*.clusters.tsv"))
unit = "table" if len(cluster_files) == 1 else "tables"
task = self.progress.add_task("Loading", total=len(cluster_files), unit=unit)
task = self.progress.add_task("Loading", total=len(cluster_files), unit=unit, precision="")
# load the original coordinates from the `*.clusters.tsv` files
coordinates = {}
types = {}
for cluster_file in self.progress.track(cluster_files, task_id=task):
cluster_fh = ctx.enter_context(open(cluster_file))
for row in ClusterTable.load(cluster_fh):
ty = ";".join(sorted(ty.name for ty in row.type.unpack()))
@@ -99,10 +99,10 @@ class Convert(Command): # noqa: D101
# collect `*_clusters_{N}.gbk` files
gbk_files = glob.glob(os.path.join(self.input_dir, "*_cluster_*.gbk"))
unit = "file" if len(gbk_files) == 1 else "files"
task = self.progress.add_task("Converting", total=len(gbk_files), unit=unit)
task = self.progress.add_task("Converting", total=len(gbk_files), unit=unit, precision="")
done = 0
# rewrite GenBank files
for gbk_file in self.progress.track(gbk_files, task_id=task, total=len(gbk_files)):
# load record and ensure it comes from GECCO
record = Bio.SeqIO.read(gbk_file, "genbank")
if "GECCO-Data" not in record.annotations.get('structured_comment', {}):
@@ -137,7 +137,7 @@ class Convert(Command): # noqa: D101
# collect `*_clusters_{N}.gbk` files
gbk_files = glob.glob(os.path.join(self.input_dir, "*_cluster_*.gbk"))
unit = "file" if len(gbk_files) == 1 else "files"
task = self.progress.add_task("Converting", total=len(gbk_files), unit=unit)
task = self.progress.add_task("Converting", total=len(gbk_files), unit=unit, precision="")
done = 0
# rewrite GenBank files
for gbk_file in self.progress.track(gbk_files, task_id=task, total=len(gbk_files)):
@@ -162,7 +162,7 @@ class Convert(Command): # noqa: D101
# collect `*_clusters_{N}.gbk` files
gbk_files = glob.glob(os.path.join(self.input_dir, "*_cluster_*.gbk"))
unit = "file" if len(gbk_files) == 1 else "files"
task = self.progress.add_task("Converting", total=len(gbk_files), unit=unit)
task = self.progress.add_task("Converting", total=len(gbk_files), unit=unit, precision="")
done = 0
# rewrite GenBank files
for gbk_file in self.progress.track(gbk_files, task_id=task, total=len(gbk_files)):
......
@@ -9,14 +9,19 @@ import os
import operator
import multiprocessing
import random
import signal
import typing
from typing import Any, Dict, Union, Optional, List, TextIO, Mapping
from ._base import Command, CommandExit, InvalidArgument
import docopt
from .._utils import patch_showwarnings
from ._base import Command, CommandExit, InvalidArgument
from .annotate import Annotate
from .train import Train
class Cv(Command): # noqa: D101
class Cv(Train): # noqa: D101
summary = "perform cross validation on a training set."
@@ -26,8 +31,8 @@ class Cv(Command): # noqa: D101
gecco cv - {cls.summary}
Usage:
gecco cv kfold -f <table> [-c <data>] [options]
gecco cv loto -f <table> -c <data> [options]
gecco cv kfold --features <table>... --clusters <table> [options]
gecco cv loto --features <table>... --clusters <table> [options]
Arguments:
-f <data>, --features <table> a domain annotation table, used to
@@ -46,7 +51,16 @@ class Cv(Command): # noqa: D101
Parameters - Domain Annotation:
-e <e>, --e-filter <e> the e-value cutoff for domains to
be included [default: 1e-5]
be included.
-p <p>, --p-filter <p> the p-value cutoff for domains
to be included. [default: 1e-9]
Parameters - Training Data:
--no-shuffle disable shuffling of the data
before fitting the model.
--seed <N> the seed to initialize the RNG
with for shuffling operations.
[default: 42]
Parameters - Training:
--c1 <C1> parameter for L1 regularisation.
@@ -58,90 +72,42 @@ class Cv(Command): # noqa: D101
[default: group]
--overlap <N> how much overlap to consider if
features overlap. [default: 2]
--splits <N> number of folds for cross-validation
(if running `kfold`). [default: 10]
--select <N> fraction of most significant features
to select from the training data.
--shuffle enable shuffling of stratified rows.
--correction <method> the multiple test correction method
to use when computing significance
with multiple Fisher tests.
Parameters - Cross-validation:
--splits <N> number of folds for cross-validation
(if running `kfold`). [default: 10]
"""
def _check(self) -> typing.Optional[int]:
if not isinstance(self.args, docopt.DocoptExit):
self.args["--output-dir"] = "."
super()._check()
try:
self.feature_type = self._check_flag(
"--feature-type",
str,
lambda x: x in {"single", "overlap", "group"},
hint="'single', 'overlap' or 'group'"
)
self.overlap = self._check_flag(
"--overlap",
self.output = self._check_flag("--output", str)
self.splits = self._check_flag(
"--splits",
int,
lambda x: x > 0,
hint="positive integer",
)
self.c1 = self._check_flag("--c1", float, hint="real number")
self.c2 = self._check_flag("--c2", float, hint="real number")
self.splits = self._check_flag("--splits", int, lambda x: x>1, hint="integer greater than 1")
self.e_filter = self._check_flag(
"--e-filter",
float,
lambda x: 0 <= x <= 1,
hint="real number between 0 and 1"
)
self.select = self._check_flag(
"--select",
lambda x: x if x is None else float(x),
lambda x: x is None or 0 <= x <= 1,
hint="real number between 0 and 1"
)
self.jobs = self._check_flag(
"--jobs",
int,
lambda x: x >= 0,
hint="positive or null integer"
hint="positive integer"
)
self.features = self._check_flag("--features")
self.clusters = self._check_flag("--clusters")
self.loto = self.args["loto"]
self.output = self.args["--output"]
except InvalidArgument:
raise CommandExit(1)
# --
def _load_features(self):
from ...model import FeatureTable
self.info("Loading", "features table from file", repr(self.features))
with open(self.features) as in_:
return FeatureTable.load(in_)
def _convert_to_genes(self, features):
self.info("Converting", "features to genes")
gene_count = len(set(features.protein_id))
unit = "gene" if gene_count == 1 else "genes"
task = self.progress.add_task("Feature conversion", total=gene_count, unit=unit)
genes = list(self.progress.track(
features.to_genes(),
total=gene_count,
task_id=task
))
self.info("Sorting", "genes by genomic coordinates")
genes.sort(key=operator.attrgetter("source.id", "start", "end"))
self.info("Sorting", "domains by protein coordinates")
for gene in genes:
gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
return genes
def _group_genes(self, genes):
self.info("Grouping", "genes by source sequence")
groups = itertools.groupby(genes, key=operator.attrgetter("source.id"))
seqs = [sorted(group, key=operator.attrgetter("start")) for _, group in groups]
if self.args["--shuffle"]:
self.info("Shuffling training data sequences")
if not self.no_shuffle:
self.info("Shuffling", "training data sequences")
random.shuffle(seqs)
return seqs
@@ -172,11 +138,13 @@ class Cv(Command): # noqa: D101
import sklearn.model_selection
return list(sklearn.model_selection.KFold(self.splits).split(seqs))
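# e.g. with `--splits 2` and four sequences this returns pairs of
# (train_indices, test_indices) such as
# [(array([2, 3]), array([0, 1])), (array([0, 1]), array([2, 3]))].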
def _get_train_data(self, train_indices, seqs):
@staticmethod
def _get_train_data(train_indices, seqs):
# extract train data
return [gene for i in train_indices for gene in seqs[i]]
def _get_test_data(self, test_indices, seqs):
@staticmethod
def _get_test_data(test_indices, seqs):
# make a clean copy of the test data without gene probabilities
test_data = [copy.deepcopy(gene) for i in test_indices for gene in seqs[i]]
for gene in test_data:
@@ -186,15 +154,7 @@ class Cv(Command): # noqa: D101
def _fit_predict(self, train_data, test_data):
from ...crf import ClusterCRF
# fit and predict the CRF for the current fold
crf = ClusterCRF(
self.feature_type,
algorithm="lbfgs",
overlap=self.overlap,
c1=self.c1,
c2=self.c2,
)
crf.fit(train_data, cpus=self.jobs, select=self.select)
crf = self._fit_model(train_data)
return crf.predict_probabilities(test_data, cpus=self.jobs)
def _write_fold(self, fold, genes, append=False):
@@ -212,11 +172,15 @@
self._check()
ctx.enter_context(self.progress)
ctx.enter_context(patch_showwarnings(self._showwarnings))
# seed RNG
self._seed_rng()
# load features
features = self._load_features()
self.success("Loaded", len(features), "feature annotations")
genes = self._convert_to_genes(features)
self.success("Recoverd", len(genes), "genes from the feature annotations")
del features
# load clusters and label genes inside clusters
clusters = self._load_clusters()
genes = self._label_genes(genes, clusters)
seqs = self._group_genes(genes)
self.success("Grouped", "genes into", len(seqs), "sequences")
# split CV folds
@@ -226,7 +190,7 @@ class Cv(Command): # noqa: D101
splits = self._kfold_splits(seqs)
# run CV
unit = "fold" if len(splits) == 1 else "folds"
task = self.progress.add_task(description="Cross-Validation", total=len(splits), unit=unit)
task = self.progress.add_task(description="Cross-Validation", total=len(splits), unit=unit, precision="")
self.info("Performing cross-validation")
for i, (train_indices, test_indices) in enumerate(self.progress.track(splits, task_id=task)):
train_data = self._get_train_data(train_indices, seqs)
......
"""Implementation of the ``gecco embed`` subcommand.
"""
import contextlib
import csv
import itertools
import multiprocessing.pool
import os
import signal
import typing
from ._base import Command, InvalidArgument, CommandExit
from .._utils import numpy_error_context, in_context, patch_showwarnings
if typing.TYPE_CHECKING:
import pandas
class Embed(Command): # noqa: D101
summary = "embed BGC annotations into non-BGC contigs for training."
@classmethod
def doc(cls, fast=False): # noqa: D102
return f"""
gecco embed - {cls.summary}
Usage:
gecco embed [--bgc <data>]... [--no-bgc <data>]... [options]
Arguments:
--bgc <data> the path to the annotation table
containing BGC-only training instances.
--no-bgc <data> the path to the annotation table
containing non-BGC training instances.
Parameters:
-M <list>, --mapping <list> an arbitrary list of which BGC
should go into which contig. Ignores
``--min-size`` and ``--skip`` when
provided.
-o <out>, --output <out> the prefix used for the output files
(which will be ``<prefix>.features.tsv``
and ``<prefix>.clusters.tsv``).
[default: embedding]
--min-size <N> the minimum size for padding sequences.
[default: 500]
-e <e>, --e-filter <e> the e-value cutoff for domains to be
included. [default: 1e-5]
-j <jobs>, --jobs <jobs> the number of CPUs to use for
multithreading. Use 0 to use all of the
available CPUs. [default: 0]
--skip <N> skip the first N contigs while creating
the embedding. [default: 0]
"""
def _check(self) -> typing.Optional[int]:
super()._check()
try:
self.skip = self._check_flag("--skip", int, lambda x: x >= 0, "positive or null integer")
self.min_size = self._check_flag("--min-size", int, lambda x: x >= 0, "positive or null integer")
self.e_filter = self._check_flag("--e-filter", float, lambda x: 0 <= x <= 1, hint="real number between 0 and 1")
self.jobs = self._check_flag("--jobs", int, lambda x: x >= 0, hint="positive or null integer")
self.mapping = self.args["--mapping"]
self.bgc = self.args["--bgc"]
self.no_bgc = self.args["--no-bgc"]
self.output = self.args["--output"]
if self.mapping is not None:
self.min_size = 0
except InvalidArgument:
raise CommandExit(1)
# ---
def _read_table(self, path: str) -> "pandas.DataFrame":
import pandas
self.info("Reading", "table from", repr(path), level=2)
return pandas.read_table(path, dtype={"domain": str})
def _read_no_bgc(self):
import pandas
self.info("Reading", "non-BGC features")
# Read the non-BGC table and assign the Y column to `0`
_jobs = os.cpu_count() if not self.jobs else self.jobs
with multiprocessing.pool.ThreadPool(_jobs) as pool:
rows = pool.map(self._read_table, self.no_bgc)
no_bgc_df = pandas.concat(rows).assign(bgc_probability=0.0)
# sort and reshape
self.info("Sorting", "non-BGC features by genomic coordinates", level=2)
no_bgc_df.sort_values(by=["sequence_id", "start", "domain_start"], inplace=True)
return [
s
for _, s in no_bgc_df.groupby("sequence_id", sort=True)
if len(s.protein_id.unique()) > self.min_size
]
def _read_bgc(self):
import pandas
self.info("Reading", "BGC features")
# Read the BGC table, assign the Y column to `1`
_jobs = os.cpu_count() if not self.jobs else self.jobs
with multiprocessing.pool.ThreadPool(_jobs) as pool:
rows = pool.map(self._read_table, self.bgc)
bgc_df = pandas.concat(rows).assign(bgc_probability=1.0)
# sort and reshape
self.info("Sorting", "BGC features by genomic coordinates", level=2)
bgc_df.sort_values(by=["sequence_id", "start", "domain_start"], inplace=True)
return [s for _, s in bgc_df.groupby("sequence_id", sort=True)]
def _read_mapping(self):
import pandas
if self.mapping is not None:
mapping = pandas.read_table(self.mapping)
return { t.bgc_id:t.sequence_id for t in mapping.itertuples() }
return None
def _check_count(self, no_bgc_list, bgc_list):
no_bgc_count, bgc_count = len(no_bgc_list) - self.skip, len(bgc_list)
if no_bgc_count < bgc_count:
self.warn("Not enough non-BGC sequences to embed BGCs:", no_bgc_count, bgc_count)
def _embed(
self,
no_bgc: "pandas.DataFrame",
bgc: "pandas.DataFrame",
) -> "pandas.DataFrame":
import pandas
import numpy
by_prots = [s for _, s in no_bgc.groupby("protein_id", sort=False)]
# cut the input in half to insert the bgc in the middle
index_half = len(by_prots) // 2
before, after = by_prots[:index_half], by_prots[index_half:]
# find the position at which the BGC is being inserted
insert_position = (before[-1].end.values[0] + after[0].start.values[0]) // 2
bgc_length = bgc.end.max() - bgc.start.min()
bgc = bgc.assign(
start=bgc.start - bgc.start.min() + insert_position,
end=bgc.end - bgc.start.min() + insert_position,
)
# shift all the 3' genes after the BGC
after = [
x.assign(start=x.start + bgc_length, end=x.end + bgc_length)
for x in after
]
# concat the embedding together and filter by e_value
embed = pandas.concat(before + [bgc] + after, sort=False)
embed = embed.reset_index(drop=True)
embed = embed[embed["i_evalue"] < self.e_filter]
# add additional columns based on info from BGC and non-BGC
with numpy_error_context(numpy, divide="ignore"):
bgc_id = bgc["sequence_id"].values[0]
sequence_id = no_bgc["sequence_id"].apply(lambda x: x).values[0]
embed = embed.assign(sequence_id=sequence_id, BGC_id=bgc_id)
# return the embedding
self.success("Finished", "embedding", repr(bgc_id), "into", repr(sequence_id), level=2)
return embed
def _make_embeddings(self, no_bgc_list, bgc_list, mapping):
import pandas
self.info("Embedding", len(bgc_list), "BGCs into", len(no_bgc_list), "contigs")
_jobs = os.cpu_count() if not self.jobs else self.jobs
unit = "BGC" if len(bgc_list) == 1 else "BGCs"
task = self.progress.add_task("Embedding", unit=unit, total=len(bgc_list))
if mapping is None:
it = zip(itertools.islice(no_bgc_list, self.skip, None), bgc_list)
else:
no_bgc_index = {x.sequence_id.values[0]:x for x in no_bgc_list}
it = [(no_bgc_index[mapping[ bgc.sequence_id.values[0] ]], bgc) for bgc in bgc_list]
embeddings = pandas.concat([
self._embed(*args)
for args in self.progress.track(it, task_id=task, total=len(bgc_list))
])
embeddings.sort_values(by=["sequence_id", "start", "domain_start"], inplace=True)
return embeddings
def _write_clusters(self, embeddings):
self.info("Writing", "clusters table to file", repr(f"{self.output}.clusters.tsv"))
with open(f"{self.output}.clusters.tsv", "w") as f:
writer = csv.writer(f, dialect="excel-tab")
writer.writerow([
"sequence_id", "bgc_id", "start", "end", "average_p", "max_p",
"type", "alkaloid_probability", "polyketide_probability",
"ripp_probability", "saccharide_probability",
"terpene_probability", "nrp_probability",
"other_probability", "proteins", "domains"
])
positives = embeddings[embeddings.bgc_probability == 1.0]
for sequence_id, domains in positives.groupby("sequence_id"):
# ty = domains.BGC_type.values[0]
writer.writerow([
sequence_id,
domains.BGC_id.values[0],
domains.start.min(),
domains.end.max(),
domains.bgc_probability.values[0],
domains.bgc_probability.values[0],
"Unknown",
# int("Alkaloid" in ty),
# int("Polyketide" in ty),
# int("RiPP" in ty),
# int("Saccharide" in ty),
# int("Terpene" in ty),
# int("NRP" in ty),
# int("Other" in ty),
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
";".join(sorted(set(domains.protein_id))),
";".join(sorted(set(domains.domain)))
])
def _write_features(self, embeddings):
self.info("Writing", "features table to file", repr(f"{self.output}.features.tsv"))
hmm_mapping = dict(PF="Pfam", TI="Tigrfam", PT="Panther", SM="smCOGs", RF="Resfams")
columns = [
'sequence_id', 'protein_id', 'start', 'end', 'strand', 'domain',
'hmm', 'i_evalue', 'domain_start', 'domain_end', 'bgc_probability'
]
embeddings[columns].to_csv(f"{self.output}.features.tsv", sep="\t", index=False)
# ---
def execute(self, ctx: contextlib.ExitStack) -> int: # noqa: D102
try:
# check arguments and enter context
self._check()
ctx.enter_context(self.progress)
ctx.enter_context(patch_showwarnings(self._showwarnings))
# load inputs
no_bgc_list = self._read_no_bgc()
bgc_list = self._read_bgc()
mapping = self._read_mapping()
if mapping is None:
self._check_count(no_bgc_list, bgc_list)
# make embeddings
embeddings = self._make_embeddings(no_bgc_list, bgc_list, mapping)
# write outputs
self._write_features(embeddings)
self._write_clusters(embeddings)
except CommandExit as cexit:
return cexit.code
except KeyboardInterrupt:
self.error("Interrupted")
return -signal.SIGINT
except Exception as err:
self.progress.stop()
raise
else:
return 0
@@ -61,7 +61,9 @@ class Run(Annotate): # noqa: D101
Parameters - Domain Annotation:
-e <e>, --e-filter <e> the e-value cutoff for protein domains
to be included. [default: 1e-5]
to be included.
-p <p>, --p-filter <p> the p-value cutoff for protein domains
to be included. [default: 1e-9]
Parameters - Cluster Detection:
-c, --cds <N> the minimum number of coding sequences a
@@ -84,7 +86,20 @@ class Run(Annotate): # noqa: D101
Command._check(self)
try:
self.cds = self._check_flag("--cds", int, lambda x: x > 0, hint="positive integer")
self.e_filter = self._check_flag("--e-filter", float, lambda x: 0 <= x <= 1, hint="real number between 0 and 1")
self.e_filter = self._check_flag(
"--e-filter",
float,
lambda x: x > 0,
hint="real number above 0",
optional=True,
)
self.p_filter = self._check_flag(
"--p-filter",
float,
lambda x: x > 0,
hint="real number above 0",
optional=True,
)
if self.args["--threshold"] is None:
self.threshold = 0.4 if self.args["--postproc"] == "gecco" else 0.6
else:
@@ -109,7 +124,7 @@ class Run(Annotate): # noqa: D101
os.makedirs(self.output_dir, exist_ok=True)
except OSError as err:
self.error("Could not create output directory: {}", err)
raise CommandExit(e.errno) from err
raise CommandExit(err.errno) from err
# Check if output files already exist
base, _ = os.path.splitext(os.path.basename(self.genome))
@@ -132,7 +147,7 @@ class Run(Annotate): # noqa: D101
self.info("Predicting", "cluster probabilitites with the CRF model", level=1)
unit = "genes" if len(genes) > 1 else "gene"
task = self.progress.add_task("Prediction", total=len(genes), unit=unit)
task = self.progress.add_task("Prediction", total=len(genes), unit=unit, precision="")
return list(crf.predict_probabilities(
self.progress.track(genes, task_id=task, total=len(genes)),
cpus=self.jobs
@@ -155,7 +170,7 @@ class Run(Annotate): # noqa: D101
total = len({gene.source.id for gene in genes})
unit = "contigs" if total > 1 else "contig"
task = self.progress.add_task("Segmentation", total=total, unit=unit)
task = self.progress.add_task("Segmentation", total=total, unit=unit, precision="")
clusters = []
gene_groups = itertools.groupby(genes, lambda g: g.source.id)
@@ -171,7 +186,7 @@ class Run(Annotate): # noqa: D101
self.info("Predicting", "BGC types", level=1)
unit = "cluster" if len(clusters) == 1 else "clusters"
task = self.progress.add_task("Type prediction", total=len(clusters), unit=unit)
task = self.progress.add_task("Type prediction", total=len(clusters), unit=unit, precision="")
clusters_new = []
classifier = TypeClassifier.trained(self.model)
@@ -292,7 +307,8 @@ class Run(Annotate): # noqa: D101
self._write_clusters(clusters)
if self.antismash_sideload:
self._write_sideload_json(clusters)
self.success("Found", len(clusters), "biosynthetic gene clusters", level=0)
unit = "cluster" if len(clusters) == 1 else "clusters"
self.success("Found", len(clusters), "biosynthetic gene", unit, level=0)
except CommandExit as cexit:
return cexit.code
except KeyboardInterrupt:
......
"""Implementation of the ``gecco train`` subcommand.
"""
import collections
import contextlib
import csv
import hashlib
import io
import itertools
import os
import operator
import pickle
import random
import signal
import typing
from typing import Any, Dict, Union, Optional, List, TextIO, Mapping
from .._utils import in_context, patch_showwarnings, ProgressReader
from ._base import Command, CommandExit, InvalidArgument
from .._utils import in_context, patch_showwarnings
from .annotate import Annotate
if typing.TYPE_CHECKING:
from ...crf import ClusterCRF
from ...model import Cluster, Gene, FeatureTable, ClusterTable
class Train(Command): # noqa: D101
@@ -21,12 +29,12 @@ class Train(Command): # noqa: D101
summary = "train the CRF model on an embedded feature table."
@classmethod
def doc(cls, fast=False): # noqa: D102
def doc(cls, fast: bool = False) -> str: # noqa: D102
return f"""
gecco train - {cls.summary}
Usage:
gecco train --features <table> --clusters <table> [options]
gecco train --features <table>... --clusters <table> [options]
Arguments:
-f <data>, --features <table> a domain annotation table, used to
@@ -44,7 +52,16 @@ class Train(Command): # noqa: D101
Parameters - Domain Annotation:
-e <e>, --e-filter <e> the e-value cutoff for domains to
be included [default: 1e-5]
be included.
-p <p>, --p-filter <p> the p-value cutoff for domains to
be included. [default: 1e-9]
Parameters - Training Data:
--no-shuffle disable shuffling of the data
before fitting the model.
--seed <N> the seed to initialize the RNG
with for shuffling operations.
[default: 42]
Parameters - Training:
--c1 <C1> parameter for L1 regularisation.
@@ -54,18 +71,19 @@ class Train(Command): # noqa: D101
--feature-type <type> how features should be extracted
(single, overlap, or group).
[default: group]
--truncate <N> the maximum number of rows to use from
the training set.
--overlap <N> how much overlap to consider if
features overlap. [default: 2]
--no-shuffle disable shuffling of the data before
fitting the model.
--select <N> fraction of most significant features
to select from the training data.
--correction <method> the multiple test correction method
to use when computing significance
with multiple Fisher tests.
"""
def _check(self) -> typing.Optional[int]:
from ...crf.select import _CORRECTION_METHODS
super()._check()
try:
self.feature_type = self._check_flag(
@@ -74,12 +92,6 @@ class Train(Command): # noqa: D101
lambda x: x in {"single", "overlap", "group"},
hint="'single', 'overlap' or 'group'"
)
self.truncate = self._check_flag(
"--truncate",
lambda x: x if x is None else int(x),
lambda x: x is None or x > 0,
hint="positive integer"
)
self.overlap = self._check_flag(
"--overlap",
int,
@@ -91,14 +103,30 @@ class Train(Command): # noqa: D101
self.e_filter = self._check_flag(
"--e-filter",
float,
lambda x: 0 <= x <= 1,
hint="real number between 0 and 1"
lambda x: x > 0,
hint="real number above 0",
optional=True,
)
self.p_filter = self._check_flag(
"--p-filter",
float,
lambda x: x > 0,
hint="real number above 0",
optional=True,
)
self.select = self._check_flag(
"--select",
lambda x: x if x is None else float(x),
lambda x: x is None or 0 <= x <= 1,
hint="real number between 0 and 1"
float,
lambda x: 0 <= x <= 1,
hint="real number between 0 and 1",
optional=True,
)
self.correction = self._check_flag(
"--correction",
str,
lambda m: m in _CORRECTION_METHODS,
hint="one of {}".format(", ".join(sorted(_CORRECTION_METHODS))),
optional=True
)
self.jobs = self._check_flag(
"--jobs",
@@ -106,16 +134,20 @@ class Train(Command): # noqa: D101
lambda x: x >= 0,
hint="positive or null integer"
)
self.features = self._check_flag("--features")
self.no_shuffle = self._check_flag("--no-shuffle", bool)
self.seed = self._check_flag("--seed", int)
self.output_dir = self._check_flag("--output-dir", str)
self.features = self._check_flag("--features", str)
self.features = self._check_flag("--features", list)
self.clusters = self._check_flag("--clusters", str)
except InvalidArgument:
raise CommandExit(1)
# ---
def _seed_rng(self):
self.info("Seeding", "the random number generator with seed", self.seed, level=2)
random.seed(self.seed)
def _make_output_directory(self) -> None:
# Make output directory
self.info("Using", "output folder", repr(self.output_dir), level=1)
@@ -123,7 +155,7 @@ class Train(Command): # noqa: D101
os.makedirs(self.output_dir, exist_ok=True)
except OSError as err:
self.error("Could not create output directory: {}", err)
raise CommandExit(e.errno) from err
raise CommandExit(err.errno) from err
# Check if output files already exist
files = [
"model.pkl",
@@ -137,19 +169,33 @@ class Train(Command): # noqa: D101
self.warn("Output folder contains files that will be overwritten")
break
def _load_features(self):
def _load_features(self) -> "FeatureTable":
from ...model import FeatureTable
self.info("Loading", "features table from file", repr(self.features))
with open(self.features) as in_:
return FeatureTable.load(in_)
def _convert_to_genes(self, features):
features = FeatureTable()
for filename in self.features:
try:
# get filesize and unit
input_size = os.stat(filename).st_size
total, scale, unit = ProgressReader.scale_size(input_size)
task = self.progress.add_task("Loading features", total=total, unit=unit, precision=".1f")
# load features
self.info("Loading", "features table from file", repr(filename))
with ProgressReader(open(filename, "rb"), self.progress, task, scale) as in_:
features += FeatureTable.load(io.TextIOWrapper(in_))
except FileNotFoundError as err:
self.error("Could not find feature file:", repr(filename))
raise CommandExit(err.errno) from err
self.success("Loaded", "a total of", len(features), "features", level=1)
return features
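# e.g. `gecco train --features a.features.tsv --features b.features.tsv ...`
# (file names illustrative) loads every table given and concatenates them
# with `+=` before the gene conversion below.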
def _convert_to_genes(self, features: "FeatureTable") -> List["Gene"]:
self.info("Converting", "features to genes")
gene_count = len(set(features.protein_id))
unit = "gene" if gene_count == 1 else "genes"
task = self.progress.add_task("Feature conversion", total=gene_count, unit=unit)
task = self.progress.add_task("Converting features", total=gene_count, unit=unit, precision="")
genes = list(self.progress.track(
features.to_genes(),
@@ -157,6 +203,9 @@ class Train(Command): # noqa: D101
task_id=task
))
# filter domains out
Annotate._filter_domains(self, genes)
self.info("Sorting", "genes by genomic coordinates")
genes.sort(key=operator.attrgetter("source.id", "start", "end"))
self.info("Sorting", "domains by protein coordinates")
@@ -164,11 +213,13 @@ class Train(Command): # noqa: D101
gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
return genes
def _fit_model(self, genes):
def _fit_model(self, genes: List["Gene"]) -> "ClusterCRF":
from ...crf import ClusterCRF
self.info("Creating" f"the CRF in {self.feature_type} mode", level=2)
self.info("Using" f"hyperparameters C1={self.c1}, C2={self.c2}", level=2)
self.info("Creating", f"the CRF in [bold blue]{self.feature_type}[/] mode", level=1)
self.info("Using", f"hyperparameters C1={self.c1}, C2={self.c2}", level=1)
if self.select is not None:
self.info("Using", f"Fisher Exact Test significance threshold of {self.select}", level=1)
crf = ClusterCRF(
self.feature_type,
algorithm="lbfgs",
@@ -177,10 +228,16 @@ class Train(Command): # noqa: D101
c2=self.c2,
)
self.info("Fitting", "the CRF model to the training data")
crf.fit(genes, select=self.select, shuffle=not self.no_shuffle, cpus=self.jobs)
crf.fit(
genes,
select=self.select,
shuffle=not self.no_shuffle,
correction_method=self.correction,
cpus=self.jobs
)
return crf
def _save_model(self, crf):
def _save_model(self, crf: "ClusterCRF") -> None:
model_out = os.path.join(self.output_dir, "model.pkl")
self.info("Pickling", "the model to", repr(model_out))
with open(model_out, "wb") as out:
@@ -196,7 +253,7 @@ class Train(Command): # noqa: D101
with open(f"{model_out}.md5", "w") as out_hash:
out_hash.write(hasher.hexdigest())
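# The sidecar `.md5` file enables a quick integrity check later, e.g.:
#
#     hashlib.md5(open(model_out, "rb").read()).hexdigest() \
#         == open(f"{model_out}.md5").read()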
def _save_transitions(self, crf):