Commits (13)
......@@ -5,7 +5,23 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.4.4...master
## [v0.4.4] - 2020-09-30
[v0.4.4]: https://git.embl.de/grp-zeller/GECCO/compare/v0.4.3...v0.4.4
### Added
- `gecco cv loto` command to run LOTO cross-validation using BGC types
for stratification.
- `header` keyword argument to `FeatureTable.dump` and `ClusterTable.dump`
to write the table without the column header, making it possible to
append to an existing table.
- `__getitem__` implementation for `FeatureTable` and `ClusterTable`
that returns a single row or a sub-table from a table.
### Fixed
- `gecco cv` command now writes results iteratively instead of holding
the tables for every fold in memory.
### Changed
- Bumped `pandas` training dependency to `v1.0`.
## [v0.4.3] - 2020-09-07
[v0.4.3]: https://git.embl.de/grp-zeller/GECCO/compare/v0.4.2...v0.4.3
......
......@@ -14,8 +14,9 @@ import tqdm
import sklearn.model_selection
from ._base import Command
from ...model import FeatureTable
from ...model import ClusterTable, FeatureTable
from ...crf import ClusterCRF
from ...crf.cv import LeaveOneGroupOut
class Cv(Command): # noqa: D101
......@@ -26,11 +27,15 @@ class Cv(Command): # noqa: D101
Usage:
gecco cv (-h | --help)
gecco cv kfold -i <data> [-w <col>]... [-f <col>]... [options]
gecco cv kfold -f <table> [-c <data>] [options]
gecco cv loto -f <table> -c <data> [options]
Arguments:
-i <data>, --input <data> a domain annotation table with regions
-f <data>, --features <table> a domain annotation table, used to
labeled as BGCs and non-BGCs.
-c <data>, --clusters <table> a cluster annotation table, used to
stratify clusters by type in LOTO
mode.
Parameters:
-o <out>, --output <out> the name of the output cross-validation
......@@ -51,8 +56,6 @@ class Cv(Command): # noqa: D101
--feature-type <type> how features should be extracted
(single, overlap, or group).
[default: group]
--truncate <N> the maximum number of rows to use from
the training set.
--overlap <N> how much overlap to consider if
features overlap. [default: 2]
--splits <N> number of folds for cross-validation
......@@ -61,20 +64,6 @@ class Cv(Command): # noqa: D101
to select from the training data.
--shuffle enable shuffling of stratified rows.
Parameters - Column Names:
-y <col>, --y-col <col> column with class label. [default: BGC]
-w <col>, --weight-cols <col> columns with local weights on features.
[default: rev_i_Evalue]
-f <col>, --feature-cols <col> column to be used as features.
[default: domain]
-s <col>, --split-col <col> column to be used for splitting into
samples, i.e different sequences
[default: sequence_id]
-g <col>, --group-col <col> column to be used for grouping features
if `--feature-type` is *group*.
[default: protein_id]
--strat-col <col> column to be used for stratifying the
samples. [default: BGC_type]
"""
def _check(self) -> typing.Optional[int]:
......@@ -82,11 +71,11 @@ class Cv(Command): # noqa: D101
if retcode is not None:
return retcode
# Check the input exists
input_ = self.args["--input"]
if not os.path.exists(input_):
self.logger.error("could not locate input file: {!r}", input_)
return 1
# Check the inputs exist
for input_ in filter(None, (self.args["--features"], self.args["--clusters"])):
if not os.path.exists(input_):
self.logger.error("could not locate input file: {!r}", input_)
return 1
# Check the `--feature-type`
type_ = self.args["--feature-type"]
......@@ -95,8 +84,6 @@ class Cv(Command): # noqa: D101
return 1
# Check value of numeric arguments
if self.args["--truncate"] is not None:
self.args["--truncate"] = int(self.args["--truncate"])
self.args["--overlap"] = int(self.args["--overlap"])
self.args["--c1"] = float(self.args["--c1"])
self.args["--c2"] = float(self.args["--c2"])
......@@ -116,14 +103,37 @@ class Cv(Command): # noqa: D101
return None
def __call__(self) -> int: # noqa: D102
seqs = self._load_sequences()
if self.args["loto"]:
splits = self._loto_splits(seqs)
else:
splits = self._kfold_splits(seqs)
self.logger.info("Performing cross-validation")
for i, (train_indices, test_indices) in enumerate(tqdm.tqdm(splits)):
train_data = [gene for i in train_indices for gene in seqs[i]]
test_data = [gene for i in test_indices for gene in seqs[i]]
crf = ClusterCRF(
self.args["--feature-type"],
algorithm="lbfgs",
overlap=self.args["--overlap"],
c1=self.args["--c1"],
c2=self.args["--c2"],
)
crf.fit(train_data, jobs=self.args["--jobs"], select=self.args["--select"])
new_genes = crf.predict_probabilities(test_data, jobs=self.args["--jobs"])
with open(self.args["--output"], "a" if i else "w") as out:
FeatureTable.from_genes(new_genes).dump(out, header=i==0)
# --- LOADING AND PREPROCESSING --------------------------------------
# Load the table
self.logger.info("Loading the data")
with open(self.args["--input"]) as in_:
return 0
def _load_sequences(self):
self.logger.info("Loading the feature table")
with open(self.args["--features"]) as in_:
table = FeatureTable.load(in_)
# Converting table to genes and sort by location
self.logger.info("Converting data to genes")
gene_count = len(set(table.protein_id))
genes = list(tqdm.tqdm(table.to_genes(), total=gene_count))
......@@ -134,32 +144,29 @@ class Cv(Command): # noqa: D101
for gene in genes:
gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
# group by sequence
self.logger.info("Grouping genes by source sequence")
groups = itertools.groupby(genes, key=operator.attrgetter("source.id"))
seqs = [sorted(group, key=operator.attrgetter("start")) for _, group in groups]
# shuffle if required
if self.args["--shuffle"]:
self.logger.info("Shuffling training data sequences")
random.shuffle(seqs)
# --- CROSS-VALIDATION ------------------------------------------------
k = 10
splits = list(sklearn.model_selection.KFold(k).split(seqs))
new_genes = []
return seqs
self.logger.info("Performing cross-validation")
for i, (train_indices, test_indices) in enumerate(tqdm.tqdm(splits)):
train_data = [gene for i in train_indices for gene in seqs[i]]
test_data = [gene for i in test_indices for gene in seqs[i]]
crf = ClusterCRF(
self.args["--feature-type"],
algorithm="lbfgs",
overlap=self.args["--overlap"],
c1=self.args["--c1"],
c2=self.args["--c2"],
)
crf.fit(train_data, jobs=self.args["--jobs"], select=self.args["--select"])
new_genes.extend(crf.predict_probabilities(test_data, jobs=self.args["--jobs"]))
def _loto_splits(self, seqs):
    """Build LOTO cross-validation splits stratified by BGC product type.

    Every sequence in *seqs* is assigned the (unpacked) product type of
    the first of its genes found in the clusters table, and the splits
    are generated with `LeaveOneGroupOut` over those type labels.
    """
    self.logger.info("Loading the clusters table")
    with open(self.args["--clusters"]) as in_:
        clusters = ClusterTable.load(in_)
    # map every protein to the product type of the cluster it belongs to
    type_by_protein = {
        protein: row.type
        for row in clusters
        for protein in row.proteins
    }
    # NOTE(review): `next` raises `StopIteration` if a sequence has no gene
    # in the clusters table — presumably the table is expected to cover all
    # sequences; confirm with the callers.
    labels = [
        next(
            type_by_protein[gene.id]
            for gene in cluster
            if gene.id in type_by_protein
        ).unpack()
        for cluster in seqs
    ]
    return list(LeaveOneGroupOut().split(seqs, groups=labels))
with open(self.args["--output"], "w") as out:
FeatureTable.from_genes(new_genes).dump(out)
def _kfold_splits(self, seqs):
    """Build plain k-fold cross-validation splits over *seqs*.

    The number of folds is taken from the ``--splits`` command-line
    argument.
    """
    folds = sklearn.model_selection.KFold(self.args["--splits"])
    return list(folds.split(seqs))
......@@ -117,6 +117,7 @@ class Train(Command): # noqa: D101
features = FeatureTable.load(in_)
# Converting table to genes and sort by location
self.logger.info("Sorting genes by location")
genes = sorted(features.to_genes(), key=operator.attrgetter("source.id", "start", "end"))
for gene in genes:
gene.protein.domains.sort(key=operator.attrgetter("start", "end"))
......
......@@ -12,29 +12,30 @@ import sklearn.model_selection
class LeaveOneGroupOut(sklearn.model_selection.LeaveOneGroupOut):
"""A `~sklearn.model_selection.LeaveOneGroupOut` supporting multiple labels.
If a sample has multiple class labels, it will be left out of the training
data each time the fold corresponds to one of its class labels::
If a sample has multiple class labels, it will be excluded from both
training and testing data when one of its labels corresponds to the
fold::
>>> loto = LeaveOneGroupOut()
>>> groups = numpy.array([ ["a"], ["b"], ["c"], ["a", "b"] ])
>>> for i, (trn, tst) in enumerate(loto.split(range(4), groups=groups)):
... print("-"*20)
... print(" FOLD", i+1)
... print("TRAIN", trn, list(groups[trn]))
... print(" TEST", tst, list(groups[tst]))
... print("TRAIN", f"{str(trn):<7}", list(groups[trn]))
... print(" TEST", f"{str(tst):<7}", list(groups[tst]))
...
--------------------
FOLD 1
TRAIN [1 2] [['b'], ['c']]
TEST [0 3] [['a'], ['a', 'b']]
TRAIN [1 2] [['b'], ['c']]
TEST [0] [['a']]
--------------------
FOLD 2
TRAIN [0 2] [['a'], ['c']]
TEST [1 3] [['b'], ['a', 'b']]
TRAIN [0 2] [['a'], ['c']]
TEST [1] [['b']]
--------------------
FOLD 3
TRAIN [0 1 3] [['a'], ['b'], ['a', 'b']]
TEST [2] [['c']]
TEST [2] [['c']]
"""
......@@ -75,18 +76,16 @@ class LeaveOneGroupOut(sklearn.model_selection.LeaveOneGroupOut):
labels = {label for labels in groups for label in labels}
return len(labels)
def split(self, X, y=None, groups=None):  # noqa: D102
    """Yield one ``(train_indices, test_indices)`` pair per unique label.

    A sample whose label set contains the current fold's label *and*
    other labels is excluded from both the training and the testing
    indices of that fold (see the class docstring).

    Arguments:
        X (sequence): The samples to split.
        y (object): Ignored; present for scikit-learn API compatibility.
        groups (iterable of iterables): For each sample, the collection
            of labels it belongs to.

    Raises:
        ValueError: When ``groups`` is `None`, or contains fewer than
            two unique labels.
    """
    if groups is None:
        raise ValueError("The 'groups' parameter should not be None")
    # Collect the groups once so that a generator argument is not
    # exhausted by the per-fold mask computations below.
    group_sets: List[Set[object]] = list(map(set, groups))
    unique_groups = {label for labels in group_sets for label in labels}
    if len(unique_groups) <= 1:
        raise ValueError(
            "The groups parameter contains fewer than 2 unique groups "
            f"({unique_groups}). LeaveOneGroupOut expects at least 2."
        )
    indices = numpy.arange(len(X))
    for ty in sorted(unique_groups):
        # test fold: samples labeled with *exactly* this type
        test_mask = numpy.array([group == {ty} for group in group_sets])
        # training fold: samples not labeled with this type at all
        train_mask = numpy.array([ty not in group for group in group_sets])
        yield indices[train_mask], indices[test_mask]
......@@ -362,7 +362,15 @@ class _UnknownSeq(Seq):
def __init__(self) -> None:
    # No backing data is stored: `__getitem__` synthesizes "N" characters
    # on demand, so the parent `Seq` is initialized with an empty string.
    super().__init__(data="")
@typing.overload
def __getitem__(self, index: int) -> str:
    pass

@typing.overload
def __getitem__(self, index: slice) -> Seq:
    pass

def __getitem__(self, index: Union[slice, int]) -> Union[str, Seq]:
    """Return ``"N"`` for an index, or a run of ``N``s for a slice.

    Note:
        Assumes slices have concrete ``start`` and ``stop`` bounds;
        an open-ended slice such as ``seq[2:]`` would raise a
        `TypeError` — TODO confirm callers never pass one.
    """
    if isinstance(index, slice):
        length = (index.stop - index.start) // (index.step or 1)
        return Seq("N" * length)
    return "N"
......@@ -399,7 +407,7 @@ class FeatureTable(Dumpable, Sized):
i_evalue: float
domain_start: int
domain_end: int
bgc_probability: float
bgc_probability: Optional[float]
@classmethod
def from_genes(cls, genes: Iterable[Gene]) -> "FeatureTable":
......@@ -439,6 +447,9 @@ class FeatureTable(Dumpable, Sized):
gene.protein.domains.append(domain)
yield gene
def __bool__(self) -> bool:  # noqa: D105
    # a table is truthy exactly when it contains at least one row
    return bool(len(self))
def __len__(self) -> int:  # noqa: D105
    # the row count equals the length of any column; `sequence_id` is used
    return len(self.sequence_id)
......@@ -448,7 +459,22 @@ class FeatureTable(Dumpable, Sized):
row = { c: getter(self)[i] for c, getter in columns.items() }
yield self.Row(**row)
def dump(self, fh: TextIO, dialect: str = "excel-tab") -> None:
@typing.overload
def __getitem__(self, item: slice) -> "FeatureTable":  # noqa: D105
    pass

@typing.overload
def __getitem__(self, item: int) -> Row:  # noqa: D105
    pass

def __getitem__(self, item: Union[slice, int]) -> Union["FeatureTable", "Row"]:  # noqa: D105
    """Extract a sub-table (for a slice) or a single `Row` (for an index)."""
    values = [getattr(self, name)[item] for name in self.__annotations__]
    wrap = type(self) if isinstance(item, slice) else self.Row
    return wrap(*values)
def dump(self, fh: TextIO, dialect: str = "excel-tab", header: bool = True) -> None:
"""Write the feature table in CSV format to the given file.
Arguments:
......@@ -456,13 +482,17 @@ class FeatureTable(Dumpable, Sized):
to write the feature table to.
dialect (`str`): The CSV dialect to use. See `csv.list_dialects`
for allowed values.
header (`bool`): Whether or not to include the column header when
writing the table (useful for appending to an existing table).
Defaults to `True`.
"""
writer = csv.writer(fh, dialect=dialect)
header = list(self.__annotations__)
writer.writerow(header)
columns = list(self.__annotations__)
if header:
writer.writerow(columns)
for row in self:
writer.writerow([ getattr(row, col) for col in header ])
writer.writerow([ getattr(row, col) for col in columns ])
@classmethod
def load(cls, fh: TextIO, dialect: str = "excel-tab") -> "FeatureTable":
......@@ -496,6 +526,8 @@ class FeatureTable(Dumpable, Sized):
for row in reader:
for index, _, append, ty in columns:
append(None if index is None else ty(row[index]))
return table
......@@ -504,6 +536,25 @@ class ClusterTable(Dumpable, Sized):
"""A table storing condensed information from several clusters.
"""
sequence_id: List[str] = field(default_factory = list)
bgc_id: List[str] = field(default_factory = list)
start: List[int] = field(default_factory = lambda: array("l")) # type: ignore
end: List[int] = field(default_factory = lambda: array("l")) # type: ignore
average_p: List[float] = field(default_factory = lambda: array("d")) # type: ignore
max_p: List[float] = field(default_factory = lambda: array("d")) # type: ignore
type: List[ProductType] = field(default_factory = list)
alkaloid_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
polyketide_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
ripp_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
saccharide_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
terpene_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
nrp_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
other_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
proteins: List[List[str]] = field(default_factory = list)
domains: List[List[str]] = field(default_factory = list)
class Row(NamedTuple):
"""A single row in a cluster table.
"""
......@@ -525,25 +576,6 @@ class ClusterTable(Dumpable, Sized):
proteins: List[str]
domains: List[str]
sequence_id: List[str] = field(default_factory = list)
bgc_id: List[str] = field(default_factory = list)
start: List[int] = field(default_factory = lambda: array("l")) # type: ignore
end: List[int] = field(default_factory = lambda: array("l")) # type: ignore
average_p: List[float] = field(default_factory = lambda: array("d")) # type: ignore
max_p: List[float] = field(default_factory = lambda: array("d")) # type: ignore
type: List[ProductType] = field(default_factory = list)
alkaloid_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
polyketide_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
ripp_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
saccharide_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
terpene_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
nrp_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
other_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
proteins: List[List[str]] = field(default_factory = list)
domains: List[List[str]] = field(default_factory = list)
@classmethod
def from_clusters(cls, clusters: Iterable[Cluster]) -> "ClusterTable":
"""Create a new cluster table from an iterable of clusters.
......@@ -580,7 +612,25 @@ class ClusterTable(Dumpable, Sized):
row = { c: getter(self)[i] for c, getter in columns.items() }
yield self.Row(**row)
def dump(self, fh: TextIO, dialect: str = "excel-tab") -> None:
def __bool__(self) -> bool:  # noqa: D105
    # a table is truthy exactly when it contains at least one row
    return bool(len(self))
@typing.overload
def __getitem__(self, item: slice) -> "ClusterTable":  # noqa: D105
    pass

@typing.overload
def __getitem__(self, item: int) -> Row:  # noqa: D105
    pass

def __getitem__(self, item: Union[int, slice]) -> Union["ClusterTable", Row]:  # noqa: D105
    """Extract a sub-table (for a slice) or a single `Row` (for an index)."""
    values = [getattr(self, name)[item] for name in self.__annotations__]
    wrap = type(self) if isinstance(item, slice) else self.Row
    return wrap(*values)
def dump(self, fh: TextIO, dialect: str = "excel-tab", header: bool = True) -> None:
"""Write the cluster table in CSV format to the given file.
Arguments:
......@@ -588,14 +638,19 @@ class ClusterTable(Dumpable, Sized):
to write the cluster table to.
dialect (`str`): The CSV dialect to use. See `csv.list_dialects`
for allowed values.
header (`bool`): Whether or not to include the column header when
writing the table (useful for appending to an existing table).
Defaults to `True`.
"""
writer = csv.writer(fh, dialect=dialect)
header = list(self.__annotations__)
writer.writerow(header)
columns = list(self.__annotations__)
row = []
if header:
writer.writerow(columns)
for i in range(len(self)):
row = []
for col in header:
row.clear()
for col in columns:
value = getattr(self, col)[i]
if col == "type":
types = value.unpack() or [ProductType.Unknown]
......
......@@ -48,7 +48,7 @@ install_requires =
train =
fisher ~=0.1.9
statsmodels ~=0.11.1
pandas ~=1.0
[options.packages.find]
include = gecco
......