......@@ -21,10 +21,10 @@ variables:
after_script:
- python -m coverage combine
- python -m coverage report
- python ci/gitlab/after_script.hmmfilter.py
artifacts:
paths:
- ci/artifacts
# - python ci/gitlab/after_script.hmmfilter.py
# artifacts:
# paths:
# - ci/artifacts
# --- Stages -------------------------------------------------------------------
......
......@@ -5,10 +5,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.3.0...master
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.4.0...master
## [v0.4.0] - 2020-08-06
[v0.4.0]: https://git.embl.de/grp-zeller/GECCO/compare/v0.3.0...v0.4.0
### Added
- `gecco.model.ProductType` enum to model the biosynthetic class of a BGC.
### Removed
- `pandas` interaction from internal data model.
- `ClusterCRF` code specific to cross-validation.
### Changed
- `pandas`, `fisher` and `statsmodels` dependencies are now optional.
- `gecco train` command now expects a cluster table in addition to the
feature table, in order to know the types of the input BGCs.
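Judging from the `cluster.type.unpack()` call that appears later in the
`gecco/cli/commands/train.py` hunk, the new `ProductType` is presumably a
flag-like enum whose composite values can be split into atomic members. A
minimal sketch of the idea (the member names are assumptions, not the
actual list):

    import enum

    class ProductType(enum.Flag):
        # hypothetical members, for illustration only
        Unknown = 0
        Polyketide = enum.auto()
        NRP = enum.auto()

        def unpack(self):
            # split a composite type (e.g. Polyketide|NRP) into atomic members
            return [ty for ty in type(self) if ty and ty in self]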
## [v0.3.0] - 2020-08-03
[v0.2.2]: https://git.embl.de/grp-zeller/GECCO/compare/v0.2.2...v0.3.0
[v0.3.0]: https://git.embl.de/grp-zeller/GECCO/compare/v0.2.2...v0.3.0
### Changed
- Replaced Nearest-Neighbours classifier with Random Forest to perform type
prediction for candidate BGCs.
......
......@@ -5,4 +5,5 @@ include gecco/_version.txt
recursive-include gecco/crf *.pkl *.pkl.md5
recursive-include gecco/hmmer *.ini
recursive-include gecco/knn *.tsv *.npz
recursive-include gecco/types *.tsv *.npz
recursive-include gecco/interpro *.json.gz
......@@ -3,9 +3,12 @@ import gzip
import re
import os
import io
import sys
import tqdm
from gecco.crf import ClusterCRF
import pkg_resources
sys.path.insert(0, os.path.realpath(os.path.join(__file__, "..", "..", "..")))
from gecco.hmmer import embedded_hmms
from gecco.interpro import InterPro
......@@ -15,21 +18,22 @@ os.makedirs(os.path.join("ci", "artifacts"), exist_ok=True)
# Load InterPro to know how many entries we have to process
interpro = InterPro.load()
# Load the internal CRF model and compile a regex that matches the domains
# Load the domains used by the CRF and compile a regex that matches the domains
# known to the CRF (i.e. useful domains for us to annotate with)
crf = ClusterCRF.trained()
rx = re.compile("|".join(crf.model.attributes_).encode("utf-8"))
with pkg_resources.resource_stream("gecco.types", "domains.tsv") as f:
domains = [ domain.strip() for domain in f ]
rx = re.compile(b"|".join(domains))
# Filter the hmms
for hmm in embedded_hmms():
out = os.path.join("ci", "artifacts", "{}.hmm.gz".format(hmm.id))
in_ = os.path.join("ci", "cache", "{}.{}.hmm.gz".format(hmm.id, hmm.version))
size = sum(1 for e in interpro.entries if e.source_database.upper().startswith(hmm.id.upper()))
pbar = tqdm.tqdm(desc=hmm.id, total=size)
with contextlib.ExitStack() as ctx:
pbar = ctx.enter_context(pbar)
src = ctx.enter_context(gzip.open(hmm.path, "rb"))
src = ctx.enter_context(gzip.open(in_, "rb"))
dst = ctx.enter_context(gzip.open(out, "wb"))
blocklines = []
......
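One caveat in the filter script above: building the pattern with
`b"|".join(domains)` assumes the accession bytes contain no regex
metacharacters. A defensive variant (a sketch, not what the script
currently does) would escape each name first:

    import re

    # escape every accession in case a name ever contains '.', '+', etc.
    rx = re.compile(b"|".join(map(re.escape, domains)))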
......@@ -20,7 +20,7 @@ pip install -U coverage tqdm
mkdir -p ci/cache
mkdir -p build/lib/gecco/data/hmms
if [ "$CI_SERVER" == "true" ]; then
if [ "$CI_SERVER" = "yes" ]; then
QUIET="-q"
else
QUIET=""
......
......@@ -5,6 +5,8 @@ set -e
. $(dirname $(dirname $0))/functions.sh
python setup.py bdist_wheel
pip install -U dist/*.whl
WHEEL=$(python setup.py --name)-$(python setup.py --version)-py2.py3-none-any.whl
pip install -U "dist/$WHEEL[train]"
python -m coverage run -p -m unittest discover -vv
......@@ -20,5 +20,5 @@ p {
}
.class dd {
margin-left: 5%;
margin-left: 2%;
}
......@@ -9,7 +9,7 @@ API Reference
hmmer
crf
refine
knn
types
Data Model
......@@ -20,7 +20,6 @@ Data Model
.. autosummary::
:nosignatures:
Hmm
Strand
Domain
Protein
......@@ -75,9 +74,9 @@ BGC Extraction
Type Prediction
---------------
.. currentmodule:: gecco.knn
.. currentmodule:: gecco.types
.. autosummary::
:nosignatures:
ClusterKNN
TypeClassifier
......@@ -5,8 +5,6 @@ Data Model
.. automodule:: gecco.model
.. autoclass:: Hmm
.. autoclass:: Domain
.. autoclass:: Protein
......
Type Prediction
===============
.. currentmodule:: gecco.knn
.. automodule:: gecco.knn
.. currentmodule:: gecco.types
.. automodule:: gecco.types
.. autoclass:: ClusterKNN
.. autoclass:: TypeClassifier
:members:
......@@ -197,6 +197,7 @@ intersphinx_mapping = {
"pandas": ("https://pandas.pydata.org/docs/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"statsmodels": ("https://tedboy.github.io/statsmodels_doc/", None),
"biopython": ("https://biopython.org/docs/1.77/api/", None),
}
# -- Options for todo extension ----------------------------------------------
......
import abc
import errno
import io
import subprocess
import typing
from typing import Iterable, Optional, Type
from typing import Iterable, Optional, Type, TextIO
from subprocess import DEVNULL
from ._meta import classproperty
......@@ -52,3 +53,17 @@ class BinaryRunner(metaclass=abc.ABCMeta):
def __init__(self) -> None:
self.check_binary()
class Dumpable(metaclass=abc.ABCMeta):
"""A metaclass for objects that can be dumped to a text file.
"""
@abc.abstractmethod
def dump(self, fh: TextIO) -> None:
raise NotImplementedError
def dumps(self) -> str:
s = io.StringIO()
self.dump(s)
return s.getvalue()
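# A hedged usage sketch, not part of the library: `FeatureTable` and
# `ClusterTable` in `gecco.model` appear to implement this ABC, judging from
# the `FeatureTable.from_genes(genes).dump(out)` calls in the CLI commands.
class _ExampleDumpable(Dumpable):
    """A hypothetical implementer illustrating the contract."""

    def dump(self, fh: TextIO) -> None:
        fh.write("hello\n")

# `dumps` then comes for free: _ExampleDumpable().dumps() == "hello\n"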
"""Share metaprogramming helpers for GECCO.
"""Shared metaprogramming helpers for GECCO.
"""
import abc
import functools
import importlib
import operator
import typing
from typing import Callable
from multiprocessing.pool import Pool
from typing import Any, Callable, Iterable, List, Tuple, Optional, Type
if typing.TYPE_CHECKING:
from types import TracebackType
_S = typing.TypeVar("_S")
_T = typing.TypeVar("_T")
_A = typing.TypeVar("_A")
_R = typing.TypeVar("_R")
# _F = typing.TypeVar("_F", bound=Callable[[_A], _R])
class classproperty(property):
"""A class property decorator.
......@@ -35,3 +43,79 @@ class classproperty(property):
def __get__(self, obj: object, owner: "_S") -> "_T": # type: ignore
return self.f(owner)
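# A hedged usage sketch: the CLI base command applies this decorator to its
# `doc` attribute (see `@classproperty def doc(cls)` in the CLI hunk below),
# giving a property computed from the class rather than from an instance:
class _Example:
    @classproperty
    def doc(cls) -> str:  # accessed as `_Example.doc`, no instance needed
        return f"Documentation for {cls.__name__}."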
class requires:
"""A decorator for functions that require optional dependencies.
"""
def __init__(self, module_name):
self.module_name = module_name
try:
self.module = importlib.import_module(module_name)
except ImportError as err:
self.module = err
def __call__(self, func):
if isinstance(self.module, ImportError):
@functools.wraps(func)
def newfunc(*args, **kwargs):
msg = f"calling {func.__qualname__} requires module {self.module.name}"
raise RuntimeError(msg) from self.module
else:
newfunc = func
basename = self.module_name.split(".")[-1]
newfunc.__globals__[basename] = self.module
return newfunc
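# A hedged usage sketch (the decorated function is illustrative, not part of
# GECCO): the ImportError is deferred until call time, and on success the
# module is injected into the function's globals under its base name.
@requires("statsmodels.stats.multitest")
def correct_pvalues(pvalues):
    # `multitest` was bound into the globals by the decorator above
    return multitest.multipletests(pvalues, method="fdr_bh")[1]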
class OrderedPoolWrapper:
"""A `Pool` wrapper that returns results in the order they were given.
"""
class _OrderedFunc:
def __init__(self, inner: Callable[["_A"], "_R"], star: bool = False) -> None:
self.inner = inner
self.star = star
def __call__(self, args: Tuple[int, "_A"]) -> Tuple[int, "_R"]:
i, other = args
if self.star:
return i, self.inner(*other) # type: ignore
else:
return i, self.inner(other)
def __init__(self, inner: Pool) -> None:
self.inner = inner
def __enter__(self) -> "OrderedPoolWrapper":
self.inner.__enter__()
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional["TracebackType"],
) -> Optional[bool]:
return self.inner.__exit__(exc_type, exc_value, traceback)
def map(self, func: Callable[["_A"], "_R"], it: Iterable["_A"]) -> List["_R"]:
wrapped_it = enumerate(it)
wrapped_func = self._OrderedFunc(func)
results = self.inner.map(wrapped_func, wrapped_it)
results.sort(key=operator.itemgetter(0))
return list(map(operator.itemgetter(1), results))
def starmap(self, func: Callable[..., "_R"], it: Iterable[Iterable[Any]]) -> List["_R"]:
wrapped_it = enumerate(it)
wrapped_func = self._OrderedFunc(func, star=True)
results = self.inner.map(wrapped_func, wrapped_it)
results.sort(key=operator.itemgetter(0))
return list(map(operator.itemgetter(1), results))
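# A hedged usage sketch: results come back in submission order even if the
# pool completes the calls out of order. The worker must be picklable (a
# module-level function, not a lambda) when the inner pool uses processes.
def _square(x: int) -> int:
    return x * x

def _example_ordered_map() -> List[int]:
    with OrderedPoolWrapper(Pool(4)) as pool:
        return pool.map(_square, range(10))  # [0, 1, 4, ..., 81]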
......@@ -23,20 +23,17 @@ class Main(Command):
@classmethod
def _get_subcommands(cls) -> Mapping[str, Type[Command]]:
return {
cmd.name: cmd.load() for cmd in pkg_resources.iter_entry_points(__parent__)
}
commands = {}
for cmd in pkg_resources.iter_entry_points(__parent__):
try:
commands[cmd.name] = cmd.load()
except pkg_resources.DistributionNotFound:
# skip subcommands whose optional dependencies are not installed
pass
return commands
@classmethod
def _get_subcommand(cls, name: str) -> Optional[Type[Command]]:
try:
return next(
typing.cast(Type[Command], ep.load())
for ep in pkg_resources.iter_entry_points(__parent__)
if ep.name == name
)
except StopIteration:
return None
return cls._get_subcommands().get(name)
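# A hedged sketch of what these helpers consume: subcommands are declared as
# setuptools entry points grouped under the parent module name. The group and
# command names below are illustrative assumptions:
#
#     setup(
#         entry_points={
#             "gecco.cli.commands": [
#                 "run = gecco.cli.commands.run:Run",
#                 "train = gecco.cli.commands.train:Train",
#             ],
#         },
#     )
#
# `ep.load()` raises `DistributionNotFound` when a command's optional
# dependencies are missing, hence the `try`/`except` in `_get_subcommands`.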
@classproperty
def doc(cls) -> str: # type: ignore
......
......@@ -2,20 +2,21 @@
"""
import functools
import itertools
import os
import operator
import multiprocessing
import random
import typing
from typing import List
import pandas
import tqdm
import sklearn.model_selection
from ._base import Command
from ...model import FeatureTable
from ...crf import ClusterCRF
if typing.TYPE_CHECKING:
from pandas import DataFrame
class Cv(Command): # noqa: D101
......@@ -26,7 +27,6 @@ class Cv(Command): # noqa: D101
Usage:
gecco cv (-h | --help)
gecco cv kfold -i <data> [-w <col>]... [-f <col>]... [options]
gecco cv loto -i <data> [-w <col>]... [-f <col>]... [options]
Arguments:
-i <data>, --input <data> a domain annotation table with regions
......@@ -116,66 +116,50 @@ class Cv(Command): # noqa: D101
return None
def __call__(self) -> int: # noqa: D102
# --- LOADING AND PREPROCESSING --------------------------------------
self.logger.info("Loading the data from {!r}", self.args["--input"])
data_tbl = pandas.read_csv(self.args["--input"], sep="\t", encoding="utf-8")
self.logger.debug(
"Filtering results with e-value under {}", self.args["--e-filter"]
)
data_tbl = data_tbl[data_tbl["i_Evalue"] < self.args["--e-filter"]]
self.logger.debug("Splitting input by column {!r}", self.args["--split-col"])
data: List["DataFrame"] = [
s for _, s in data_tbl.groupby(self.args["--split-col"])
]
# --- LOADING AND PREPROCESSING --------------------------------------
# Load the table
self.logger.info("Loading the data")
with open(self.args["--input"]) as in_:
table = FeatureTable.load(in_)
# Convert the table to genes and sort them by location
self.logger.info("Converting data to genes")
gene_count = len(set(table.protein_id))
genes = list(tqdm.tqdm(table.to_genes(), total=gene_count))
del table
self.logger.info("Sorting genes by location")
genes.sort(key=operator.attrgetter("source.id", "start", "end"))
for gene in genes:
gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
# group by sequence
groups = itertools.groupby(genes, key=operator.attrgetter("source.id"))
seqs = [sorted(group, key=operator.attrgetter("start")) for _, group in groups]
# shuffle if required
if self.args["--shuffle"]:
self.logger.debug("Shuffling rows")
random.shuffle(data)
# --- CROSS VALIDATION -----------------------------------------------
crf = ClusterCRF(
feature_columns=self.args["--feature-cols"],
weight_columns=self.args["--weight-cols"],
feature_type=self.args["--feature-type"],
label_column=self.args["--y-col"],
overlap=self.args["--overlap"],
algorithm="lbfgs",
c1=self.args["--c1"],
c2=self.args["--c2"],
)
if self.args["loto"]:
cv_type = "loto"
cross_validate = crf.loto_cv
elif self.args["kfold"]:
cv_type = "kfold"
cross_validate = functools.partial(crf.cv, k=self.args["--splits"])
random.shuffle(seqs)
# --- CROSS-VALIDATION ------------------------------------------------
k = 10
splits = list(sklearn.model_selection.KFold(k).split(seqs))
new_genes = []
self.logger.info("Performing cross-validation")
results = pandas.concat(
cross_validate(
data,
self.args["--strat-col"],
trunc=self.args["--truncate"],
select=self.args["--select"],
jobs=self.args["--jobs"],
for i, (train_indices, test_indices) in enumerate(tqdm.tqdm(splits)):
train_data = [gene for i in train_indices for gene in seqs[i]]
test_data = [gene for i in test_indices for gene in seqs[i]]
crf = ClusterCRF(
self.args["--feature-type"],
algorithm="lbfgs",
overlap=self.args["--overlap"],
c1=self.args["--c1"],
c2=self.args["--c2"],
)
)
# Add all the arguments given from CLI as table data
self.logger.info("Formatting results")
results["c1"] = self.args["--c1"]
results["c2"] = self.args["--c2"]
results["feature_type"] = self.args["--feature-type"]
results["e_filter"] = self.args["--e-filter"]
results["overlap"] = self.args["--overlap"]
results["weight"] = ",".join(map(str, self.args["--weight-cols"]))
results["feature"] = ",".join(self.args["--feature-cols"])
results["truncate"] = self.args["--truncate"]
results["input"] = os.path.basename(self.args["--input"])
results["cv_type"] = cv_type
# Write results
self.logger.info("Writing output to {!r}", self.args["--output"])
results.to_csv(self.args["--output"], sep="\t", index=False)
return 0
crf.fit(train_data, jobs=self.args["--jobs"], select=self.args["--select"])
new_genes.extend(crf.predict_probabilities(test_data, jobs=self.args["--jobs"]))
with open(self.args["--output"], "w") as out:
FeatureTable.from_genes(new_genes).dump(out)
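For context on the new cross-validation loop: `KFold(k).split(seqs)` yields
index arrays over whole sequences, so all genes from one contig land either
in the training fold or in the test fold, never both. A standalone sketch
with toy data:

    import sklearn.model_selection

    seqs = [["gA1", "gA2"], ["gB1"], ["gC1"], ["gD1"], ["gE1"]]
    for train_idx, test_idx in sklearn.model_selection.KFold(5).split(seqs):
        train = [gene for i in train_idx for gene in seqs[i]]
        test = [gene for i in test_idx for gene in seqs[i]]
        # every sequence contributes to exactly one of the two folds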
......@@ -14,15 +14,15 @@ import signal
from typing import Union
import numpy
import pandas
from Bio import SeqIO
from ._base import Command
from .._utils import guess_sequences_format
from ...crf import ClusterCRF
from ...hmmer import HMMER, HMM, embedded_hmms
from ...types import TypeClassifier
from ...model import FeatureTable, ClusterTable
from ...orf import PyrodigalFinder
from ...types import TypeClassifier
from ...refine import ClusterRefiner
......@@ -62,7 +62,7 @@ class Run(Command): # noqa: D101
(antismash or gecco). [default: gecco]
Parameters - Debug:
--model <model.crf> the path to an alternative CRF model
--model <directory> the path to an alternative CRF model
to use (obtained with `gecco train`).
"""
......@@ -121,7 +121,7 @@ class Run(Command): # noqa: D101
self.logger.info("Running domain annotation")
# Run all HMMs over ORFs to annotate with protein domains
def annotate(hmm: HMM) -> "pandas.DataFrame":
def annotate(hmm: HMM) -> None:
self.logger.debug(
"Starting annotation with HMM {} v{}", hmm.id, hmm.version
)
......@@ -161,12 +161,10 @@ class Run(Command): # noqa: D101
self.logger.debug("Predicting BGC probabilities")
genes = crf.predict_probabilities(genes)
self.logger.debug("Extracting feature table")
feats_df = pandas.concat([g.to_feature_table() for g in genes], sort=False)
pred_out = os.path.join(out_dir, f"{base}.features.tsv")
self.logger.debug("Writing feature table to {!r}", pred_out)
feats_df.to_csv(pred_out, sep="\t", index=False)
with open(pred_out, "w") as f:
FeatureTable.from_genes(genes).dump(f)
# --- REFINE ---------------------------------------------------------
self.logger.info("Extracting gene clusters from prediction")
......@@ -196,8 +194,8 @@ class Run(Command): # noqa: D101
# Write predicted cluster coordinates to file
cluster_out = os.path.join(out_dir, f"{base}.clusters.tsv")
self.logger.debug("Writing cluster coordinates to {!r}", cluster_out)
table = pandas.concat([c.to_cluster_table() for c in clusters])
table.to_csv(cluster_out, sep="\t", index=False)
with open(cluster_out, "w") as out:
ClusterTable.from_clusters(clusters).dump(out)
# Write predicted cluster sequences to file
for cluster in clusters:
......
......@@ -7,19 +7,19 @@ import itertools
import logging
import multiprocessing.pool
import os
import operator
import pickle
import random
import typing
import numpy
import pandas
import scipy.sparse
import tqdm
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from ._base import Command
from ...model import Domain, Gene, Protein, Strand
from ...model import FeatureTable, ClusterTable, Cluster
from ...crf import ClusterCRF
from ...refine import ClusterRefiner
from ...hmmer import HMMER
......@@ -33,12 +33,14 @@ class Train(Command): # noqa: D101
Usage:
gecco train (-h | --help)
gecco train -i <data> [-w <col>]... [--feature-cols <col>]...
[--sort-cols <col>]... [--strat-cols <col>]... [options]
gecco train --features <table> --clusters <table> [options]
Arguments:
-i <data>, --input <data> a domain annotation table with regions
labeled as BGCs and non-BGCs.
-f <data>, --features <table> a domain annotation table, used to
train the CRF model.
-c <data>, --clusters <table> a cluster annotation table, used to
extract the domain composition for
the type classifier.
Parameters:
-o <out>, --output-dir <out> the directory to use for the model
......@@ -68,28 +70,6 @@ class Train(Command): # noqa: D101
--select <N> fraction of most significant features
to select from the training data.
Parameters - Column Names:
-y <col>, --y-col <col> column with class label. [default: BGC]
-w <col>, --weight-cols <col> columns with local weights on features.
[default: rev_i_Evalue]
-f <col>, --feature-cols <col> column to be used as features.
[default: domain]
-s <col>, --split-col <col> column to be used for splitting into
samples, i.e different sequences
[default: sequence_id]
-g <col>, --group-col <col> column to be used for grouping features
if `--feature-type` is *group*.
[default: protein_id]
--sort-cols <col> columns to be used for sorting the data
[default: genome_id start domain_start]
--strat-cols <col> columns to be used for stratifying the
samples (BGC types).
Parameters - Type Prediction:
--type-col <col> column containing BGC types to use for
domain composition. [default: BGC_type]
--id-col <col> column containing BGC id to use for
BGC labelling. [default: BGC_id]
"""
def _check(self) -> typing.Optional[int]:
......@@ -98,10 +78,10 @@ class Train(Command): # noqa: D101
return retcode
# Check the input exists
input_ = self.args["--input"]
if not os.path.exists(input_):
self.logger.error("could not locate input file: {!r}", input_)
return 1
for input_ in self.args["--features"], self.args["--clusters"]:
if not os.path.exists(input_):
self.logger.error("could not locate input file: {!r}", input_)
return 1
# Check the `--feature-type`
type_ = self.args["--feature-type"]
......@@ -133,115 +113,84 @@ class Train(Command): # noqa: D101
# --- LOADING AND PREPROCESSING --------------------------------------
# Load the table
self.logger.info("Loading the data")
feats_df = pandas.read_csv(self.args["--input"], sep="\t", encoding="utf-8")
self.logger.debug(
"Filtering results with e-value under {}", self.args["--e-filter"]
)
feats_df = feats_df[feats_df["i_Evalue"] < self.args["--e-filter"]]
with open(self.args["--features"]) as in_:
features = FeatureTable.load(in_)
# Computing reverse i_Evalue
self.logger.debug("Computing reverse indepenent e-value")
feats_df = feats_df.assign(rev_i_Evalue=1 - feats_df["i_Evalue"])
# Sorting data
self.logger.debug("Sorting data by gene and domain coordinates")
feats_df.sort_values(
by=[self.args["--split-col"], "start", "end", "domain_start"], inplace=True
)
# Grouping column
self.logger.debug(
"Splitting feature table using column {!r}", self.args["--split-col"]
)
training_data = [s for _, s in feats_df.groupby(self.args["--split-col"])]
if not self.args["--no-shuffle"]:
self.logger.debug("Shuffling splits randomly")
random.shuffle(training_data)
# Convert the table to genes and sort them by location
genes = sorted(features.to_genes(), key=operator.attrgetter("source.id", "start", "end"))
for gene in genes:
gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
del features
# --- MODEL FITTING --------------------------------------------------
self.logger.info("Fitting the CRF model to the training data")
crf = ClusterCRF(
label_column=self.args["--y-col"],
feature_columns=self.args["--feature-cols"],
weight_columns=self.args["--weight-cols"],
group_column=self.args["--group-col"],
feature_type=self.args["--feature-type"],
overlap=self.args["--overlap"],
self.args["--feature-type"],
algorithm="lbfgs",
overlap=self.args["--overlap"],
c1=self.args["--c1"],
c2=self.args["--c2"],
)
self.logger.info("Fitting the CRF model to the training data")
crf.fit(
data=training_data,
trunc=self.args["--truncate"],
select=self.args["--select"],
)
crf.fit(genes, select=self.args["--select"], shuffle=not self.args["--no-shuffle"])
# --- MODEL SAVING ---------------------------------------------------
os.makedirs(self.args["--output-dir"], exist_ok=True)
model_out = os.path.join(self.args["--output-dir"], "model.pkl")
self.logger.info("Writing the model to {!r}", model_out)
with open(model_out, "wb") as f:
pickle.dump(crf, f, protocol=4)
with open(model_out, "wb") as out:
pickle.dump(crf, out, protocol=4)
self.logger.debug("Computing and saving model checksum")
hasher = hashlib.md5()
with open(model_out, "rb") as f:
hasher.update(f.read()) # FIXME: iterate on file blocks
with open(f"{model_out}.md5", "w") as f:
f.write(hasher.hexdigest())
self.logger.info("Writing transitions and state weights")
crf.save_weights(self.args["--output-dir"])
# # --- DOMAIN COMPOSITION ----------------------------------------------
with open(model_out, "rb") as src:
hasher.update(src.read()) # FIXME: iterate on file blocks
with open(f"{model_out}.md5", "w") as out_hash:
out_hash.write(hasher.hexdigest())
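# A hedged fix for the FIXME above: hash the model in fixed-size blocks
# instead of reading the whole file into memory, e.g.
#
#     hasher = hashlib.md5()
#     with open(model_out, "rb") as src:
#         for block in iter(lambda: src.read(8192), b""):
#             hasher.update(block)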
self.logger.info("Writing transitions weights")
with open(os.path.join(self.args["--output-dir"], "model.trans.tsv"), "w") as f:
writer = csv.writer(f, dialect="excel-tab")
writer.writerow(["from", "to", "weight"])
for labels, weight in crf.model.transition_features_.items():
writer.writerow([*labels, weight])
self.logger.info("Writing state weights")
with open(os.path.join(self.args["--output-dir"], "model.state.tsv"), "w") as f:
writer = csv.writer(f, dialect="excel-tab")
writer.writerow(["attr", "label", "weight"])
for attrs, weight in crf.model.state_features_.items():
writer.writerow([*attrs, weight])
# --- DOMAIN COMPOSITION ---------------------------------------------
self.logger.info("Extracting domain composition for type classifier")
self.logger.debug("Loading cluster table")
with open(self.args["--clusters"]) as f:
clusters = [
Cluster(c.bgc_id, [g for g in genes if g.id in c.proteins], c.type)
for c in sorted(ClusterTable.load(f), key=operator.attrgetter("bgc_id"))
]
self.logger.debug("Finding the array of possible domains")
if crf.significant_features:
all_possible = sorted(
{d for domains in crf.significant_features.values() for d in domains}
)
if crf.significant_features is not None:
all_possible = sorted(crf.significant_features)
else:
all_possible = sorted(
{t.domain for d in training_data for t in d[d["BGC"] == 1].itertuples()}
)
types = [d[self.args["--type-col"]].values[0] for d in training_data]
ids = [d[self.args["--id-col"]].values[0] for d in training_data]
all_possible = sorted({d.name for g in genes for d in g.protein.domains})
self.logger.debug("Saving training matrix labels for BGC type classifier")
doms_out = os.path.join(self.args["--output-dir"], "domains.tsv")
pandas.Series(all_possible).to_csv(
doms_out, sep="\t", index=False, header=False
)
types_out = os.path.join(self.args["--output-dir"], "types.tsv")
df = pandas.DataFrame(dict(ids=ids, types=types))
df.to_csv(types_out, sep="\t", index=False, header=False)
with open(os.path.join(self.args["--output-dir"], "domains.tsv"), "w") as out:
out.writelines([f"{domain}\n" for domain in all_possible])
with open(os.path.join(self.args["--output-dir"], "types.tsv"), "w") as out:
writer = csv.writer(out, dialect="excel-tab")
for cluster in clusters:
types = ";".join(ty.name for ty in cluster.type.unpack())
writer.writerow([cluster.id, types])
self.logger.debug("Building new domain composition matrix")
if self.stream.isatty() and self.logger.level != 0:
pbar = tqdm.tqdm(total=len(training_data), leave=False)
else:
pbar = None
def domain_composition(table: pandas.DataFrame) -> numpy.ndarray:
is_bgc = table[self.args["--y-col"]].array == 1
names = table[self.args["--feature-cols"][0]].array[is_bgc]
unique_names = set(names)
weights = table[self.args["--weight-cols"][0]].array[is_bgc]
composition = numpy.zeros(len(all_possible))
for i, domain in enumerate(all_possible):
if domain in unique_names:
composition[i] = numpy.sum(weights[names == domain])
if pbar is not None:
pbar.update(1)
return composition / (composition.sum() or 1) # type: ignore
with multiprocessing.pool.ThreadPool(self.args["--jobs"]) as pool:
new_comp = numpy.array(pool.map(domain_composition, training_data))
if pbar is not None:
pbar.close()
comp = numpy.array([c.domain_composition(all_possible) for c in clusters])
comp_out = os.path.join(self.args["--output-dir"], "compositions.npz")
self.logger.debug("Saving new domain composition matrix to {!r}", comp_out)
scipy.sparse.save_npz(comp_out, scipy.sparse.coo_matrix(new_comp))
scipy.sparse.save_npz(comp_out, scipy.sparse.coo_matrix(comp))
return 0
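Judging from the pandas closure it replaces, `Cluster.domain_composition`
presumably sums per-domain weights over a fixed domain vocabulary and
normalizes the result into a composition vector. A hedged standalone sketch
of that computation:

    import numpy

    def domain_composition(weights_by_domain, all_possible):
        """Sketch: `weights_by_domain` maps domain name -> summed weight."""
        composition = numpy.zeros(len(all_possible))
        for i, name in enumerate(all_possible):
            composition[i] = weights_by_domain.get(name, 0.0)
        return composition / (composition.sum() or 1)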