@@ -5,10 +5,21 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.9.1-alpha2...master
[Unreleased]: https://git.embl.de/grp-zeller/GECCO/compare/v0.9.1-alpha3...master
## [v0.9.1-alpha3] - 2022-03-23
[v0.9.1-alpha3]: https://git.embl.de/grp-zeller/GECCO/compare/v0.9.1-alpha2...v0.9.1-alpha3
### Added
- `gecco.model.GeneTable` class to store gene coordinates independently of protein domains.
### Changed
- Refactored implementation of `load` and `dump` methods for `Table` classes into a dedicated base class.
- `gecco run` and `gecco annotate` now output a gene table in addition to the feature and cluster tables.
- `gecco train` expects a gene table instead of a GFF file for the gene coordinates.
## [v0.9.1-alpha2] - 2022-03-23
[v0.9.1-alpha1]: https://git.embl.de/grp-zeller/GECCO/compare/v0.9.1-alpha1...v0.9.1-alpha2
[v0.9.1-alpha2]: https://git.embl.de/grp-zeller/GECCO/compare/v0.9.1-alpha1...v0.9.1-alpha2
### Fixed
- `TypeClassifier.trained` not being able to read unknown types from type tables.
......
@@ -10,4 +10,4 @@ See Also:
__author__ = "Martin Larralde"
__license__ = "GPLv3"
__version__ = "0.9.1-alpha2"
__version__ = "0.9.1-alpha3"
import abc
import csv
import errno
import io
import itertools
import operator
import subprocess
import typing
from typing import Iterable, Optional, Type, TextIO
from collections.abc import Sized
from typing import Iterable, Iterator, List, NamedTuple, Optional, Type, TextIO, Union
from subprocess import DEVNULL
from ._meta import classproperty
from ._meta import classproperty, requires
_SELF = typing.TypeVar("_SELF")
def _parse_str(value: str) -> str:
return value
def _parse_int(value: str) -> int:
return int(value)
def _parse_float(value: str) -> float:
return float(value)
def _parse_optional_float(value: str) -> typing.Optional[float]:
return float(value) if value else None
def _parse_list_str(value: str) -> typing.List[str]:
return value.split(";")
def _parse_optional_list_str(value: str) -> typing.Optional[typing.List[str]]:
return None if not value else _parse_list_str(value)
def _format_int(value: int) -> str:
return str(value)
def _format_str(value: str) -> str:
return value
def _format_float(value: float) -> str:
return str(value)
def _format_optional_float(value: typing.Optional[float]) -> str:
return "" if value is None else str(value)
def _format_list_str(value: typing.List[str]) -> str:
return ";".join(value)
def _format_optional_list_str(value: typing.Optional[typing.List[str]]) -> str:
return "" if value is None else _format_list_str(value)
class Dumpable(metaclass=abc.ABCMeta):
@@ -21,3 +65,153 @@ class Dumpable(metaclass=abc.ABCMeta):
s = io.StringIO()
self.dump(s)
return s.getvalue()
class Loadable(metaclass=abc.ABCMeta):
"""A metaclass for objects that can be loaded from a text file.
"""
@classmethod
@abc.abstractmethod
def load(cls: typing.Type[_SELF], fh: TextIO) -> _SELF:
raise NotImplementedError
@classmethod
def loads(cls: typing.Type[_SELF], s: str) -> _SELF:
return cls.load(io.StringIO(s))
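# --- Illustrative example, not part of the diff ----------------------------
# `loads` is just `load` wrapped around an in-memory buffer. A minimal call,
# assuming the `GeneTable` subclass added to `gecco.model` later in this
# changeset; the values are made up.
from gecco.model import GeneTable

tsv = "sequence_id\tprotein_id\tstart\tend\tstrand\nseq1\tseq1_1\t100\t160\t+"
table = GeneTable.loads(tsv)  # equivalent to GeneTable.load(io.StringIO(tsv))
assert table[0].protein_id == "seq1_1"
# ----------------------------------------------------------------------------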
class Table(Dumpable, Loadable, Sized):
"""A metaclass for objects that
"""
Row: typing.Type[typing.NamedTuple]
def __bool__(self) -> bool: # noqa: D105
return len(self) != 0
def __iadd__(self: _SELF, rhs: _SELF) -> _SELF: # noqa: D105
if not isinstance(rhs, type(self)):
return NotImplemented
for col in self.__annotations__:
getattr(self, col).extend(getattr(rhs, col))
return self
@typing.overload
def __getitem__(self: _SELF, item: slice) -> _SELF: # noqa: D105
pass
@typing.overload
def __getitem__(self, item: int) -> "Row": # noqa: D105
pass
def __getitem__(self: _SELF, item: Union[slice, int]) -> Union[_SELF, "Row"]: # noqa: D105
columns = [getattr(self, col)[item] for col in self.__annotations__]
if isinstance(item, slice):
return type(self)(*columns)
else:
return self.Row(*columns)
def __iter__(self) -> Iterator["Row"]: # noqa: D105
columns = { c: operator.attrgetter(c) for c in self.__annotations__ }
for i in range(len(self)):
row = { c: getter(self)[i] for c, getter in columns.items() }
yield self.Row(**row)
@classmethod
def _optional_columns(cls) -> typing.Set[str]:
optional_columns = set()
for name, ty in cls.Row.__annotations__.items():
if ty == Optional[int] or ty == Optional[float] or ty == Optional[List[str]]:
optional_columns.add(name)
return optional_columns
_FORMAT_FIELD: dict = {
str: _format_str,
int: _format_int,
float: _format_float,
typing.Optional[float]: _format_optional_float,
typing.List[str]: _format_list_str,
typing.Optional[typing.List[str]]: _format_optional_list_str,
}
def dump(self, fh: TextIO, dialect: str = "excel-tab", header: bool = True) -> None:
"""Write the table in CSV format to the given file.
Arguments:
fh (file-like `object`): A writable file-handle opened in text mode
to write the feature table to.
dialect (`str`): The CSV dialect to use. See `csv.list_dialects`
for allowed values.
header (`bool`): Whether or not to include the column header when
writing the table (useful for appending to an existing table).
Defaults to `True`.
"""
writer = csv.writer(fh, dialect=dialect)
column_names = list(self.__annotations__)
optional = self._optional_columns()
# do not write optional columns if they are completely empty
for name in optional:
if all(x is None for x in getattr(self, name)):
column_names.remove(name)
# write header if desired
if header:
writer.writerow(column_names)
# write each row
columns = [getattr(self, name) for name in column_names]
formatters = [self._FORMAT_FIELD[self.Row.__annotations__[name]] for name in column_names]
for i in range(len(self)):
writer.writerow([format(col[i]) for col,format in zip(columns, formatters)])
_PARSE_FIELD: dict = {
str: _parse_str,
int: _parse_int,
float: _parse_float,
typing.Optional[float]: _parse_optional_float,
typing.List[str]: _parse_list_str,
typing.Optional[typing.List[str]]: _parse_optional_list_str,
}
@classmethod
def load(cls: typing.Type[_SELF], fh: TextIO, dialect: str = "excel-tab") -> _SELF:
"""Load a table in CSV format from a file handle in text mode.
"""
table = cls()
reader = csv.reader(fh, dialect=dialect)
header = next(reader)
# get the name of each column and check which columns are optional
columns = [getattr(table, col) for col in header]
optional = cls._optional_columns()
parsers = [cls._PARSE_FIELD[table.Row.__annotations__[col]] for col in header]
# check that if a column is missing, it is one of the optional values
missing = set(cls.__annotations__).difference(header)
missing_required = missing.difference(optional)
if missing_required:
raise ValueError("table is missing columns: {}".format(", ".join(missing_required)))
# extract elements from the CSV rows
for row in reader:
for col, value, parse in itertools.zip_longest(columns, row, parsers):
col.append(parse(value))
for col in missing:
getattr(table, col).extend(None for _ in range(len(table)))
return table
@requires("pandas")
def to_dataframe(self) -> "pandas.DataFrame":
"""Convert the table to a `~pandas.DataFrame`.
Raises:
ImportError: if the `pandas` module could not be imported.
"""
frame = pandas.DataFrame() # type: ignore
for column in self.__annotations__:
frame[column] = getattr(self, column)
return frame
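# --- Illustrative sketch, not part of the diff ------------------------------
# Round-tripping a table through the shared dump()/load() machinery, using the
# GeneTable subclass added to gecco.model later in this changeset; values are
# made up. Optional columns that are entirely empty (here average_p/max_p) are
# omitted from the output and filled back with None on loading.
import io
from gecco.model import GeneTable

table = GeneTable(
    sequence_id=["seq1", "seq2"],
    protein_id=["seq1_1", "seq2_1"],
    start=[100, 200],
    end=[160, 260],
    strand=["+", "-"],
    average_p=[None, None],
    max_p=[None, None],
)

buffer = io.StringIO()
table.dump(buffer)
print(buffer.getvalue().splitlines()[0])
# -> sequence_id\tprotein_id\tstart\tend\tstrand

copy = GeneTable.load(io.StringIO(buffer.getvalue()))
assert len(copy) == 2 and copy[0].start == 100 and copy[0].average_p is None
# ----------------------------------------------------------------------------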
@@ -54,8 +54,8 @@ class Annotate(Command): # noqa: D101
the available CPUs. [default: 0]
Parameters - Output:
-o <out>, --output <out> the file where to write the feature
table. [default: features.tsv]
-o <out>, --output-dir <out> the directory in which to write the
output files. [default: .]
Parameters - Gene Calling:
@@ -97,7 +97,7 @@ class Annotate(Command): # noqa: D101
self.format = self._check_flag("--format", optional=True)
self.genome = self._check_flag("--genome")
self.hmm = self._check_flag("--hmm", optional=True)
self.output = self._check_flag("--output")
self.output_dir = self._check_flag("--output-dir")
self.mask = self._check_flag("--mask", bool)
except InvalidArgument:
raise CommandExit(1)
@@ -123,6 +123,27 @@ class Annotate(Command): # noqa: D101
# ---
_OUTPUT_FILES = ["features.tsv", "genes.tsv"]
def _make_output_directory(self, extensions: List[str]) -> None:
# Make output directory
self.info("Using", "output folder", repr(self.output_dir), level=1)
try:
os.makedirs(self.output_dir, exist_ok=True)
except OSError as err:
self.error("Could not create output directory: {}", err)
raise CommandExit(err.errno) from err
# Check if output files already exist
base, _ = os.path.splitext(os.path.basename(self.genome))
# output_exts = ["features.tsv", "genes.tsv", "clusters.tsv"]
# if self.antismash_sideload:
# output_exts.append("sideload.json")
for ext in extensions:
if os.path.isfile(os.path.join(self.output_dir, f"{base}.{ext}")):
self.warn("Output folder contains files that will be overwritten")
break
def _load_sequences(self):
from Bio import SeqIO
@@ -223,10 +244,21 @@ class Annotate(Command): # noqa: D101
def _write_feature_table(self, genes):
from ...model import FeatureTable
self.info("Writing", "feature table to", repr(self.output), level=1)
with open(self.output, "w") as f:
base, _ = os.path.splitext(os.path.basename(self.genome))
pred_out = os.path.join(self.output_dir, f"{base}.features.tsv")
self.info("Writing", "feature table to", repr(pred_out), level=1)
with open(pred_out, "w") as f:
FeatureTable.from_genes(genes).dump(f)
def _write_genes_table(self, genes):
from ...model import GeneTable
base, _ = os.path.splitext(os.path.basename(self.genome))
pred_out = os.path.join(self.output_dir, f"{base}.genes.tsv")
self.info("Writing", "gene table to", repr(pred_out), level=1)
with open(pred_out, "w") as f:
GeneTable.from_genes(genes).dump(f)
# ---
def execute(self, ctx: contextlib.ExitStack) -> int: # noqa: D102
@@ -235,9 +267,13 @@ class Annotate(Command): # noqa: D101
self._check()
ctx.enter_context(self.progress)
ctx.enter_context(patch_showwarnings(self._showwarnings))
# attempt to create the output directory, checking it doesn't
# already contain output files (or raise a warning)
self._make_output_directory(extensions=["features.tsv", "genes.tsv"])
# load sequences and extract genes
sequences = self._load_sequences()
genes = self._extract_genes(sequences)
self._write_genes_table(genes)
if genes:
self.success("Found", "a total of", len(genes), "genes", level=1)
else:
......
@@ -141,25 +141,6 @@ class Run(Annotate): # noqa: D101
# ---
def _make_output_directory(self) -> None:
# Make output directory
self.info("Using", "output folder", repr(self.output_dir), level=1)
try:
os.makedirs(self.output_dir, exist_ok=True)
except OSError as err:
self.error("Could not create output directory: {}", err)
raise CommandExit(err.errno) from err
# Check if output files already exist
base, _ = os.path.splitext(os.path.basename(self.genome))
output_exts = ["features.tsv", "clusters.tsv"]
if self.antismash_sideload:
output_exts.append("sideload.json")
for ext in output_exts:
if os.path.isfile(os.path.join(self.output_dir, f"{base}.{ext}")):
self.warn("Output folder contains files that will be overwritten")
break
def _load_model_domains(self) -> typing.Set[str]:
try:
if self.model is None:
@@ -195,15 +176,6 @@ class Run(Annotate): # noqa: D101
cpus=self.jobs
))
def _write_feature_table(self, genes):
from ...model import FeatureTable
base, _ = os.path.splitext(os.path.basename(self.genome))
pred_out = os.path.join(self.output_dir, f"{base}.features.tsv")
self.info("Writing", "feature table to", repr(pred_out), level=1)
with open(pred_out, "w") as f:
FeatureTable.from_genes(genes).dump(f)
def _extract_clusters(self, genes):
from ...refine import ClusterRefiner
@@ -326,8 +298,12 @@ class Run(Annotate): # noqa: D101
self._check()
ctx.enter_context(self.progress)
ctx.enter_context(patch_showwarnings(self._showwarnings))
# attempt to create the output directory
self._make_output_directory()
# attempt to create the output directory, checking it doesn't
# already contain output files (or raise a warning)
extensions = ["features.tsv", "genes.tsv", "clusters.tsv"]
if self.antismash_sideload:
extensions.append("sideload.json")
self._make_output_directory(extensions)
# load sequences and extract genes
sequences = self._load_sequences()
genes = self._extract_genes(sequences)
@@ -342,6 +318,7 @@ class Run(Annotate): # noqa: D101
# annotate domains and predict probabilities
genes = self._annotate_domains(genes, whitelist=whitelist)
genes = self._predict_probabilities(genes)
self._write_genes_table(genes)
self._write_feature_table(genes)
# extract clusters from probability vector
clusters = self._extract_clusters(genes)
......
@@ -42,7 +42,7 @@ class Train(Command): # noqa: D101
-c <data>, --clusters <table> a cluster annotation table, used to
extract the domain composition for
the type classifier.
-g <file>, --genes <file> a GFF file containing the
-g <file>, --genes <file> a gene table containing the
coordinates of the genes inside
the training sequence.
@@ -186,26 +186,17 @@ class Train(Command): # noqa: D101
break
def _load_genes(self) -> Iterator["Gene"]:
from Bio.SeqRecord import SeqRecord
from ...model import Gene, Protein, Strand, _UnknownSeq
from ...model import GeneTable
try:
# get filesize and unit
input_size = os.stat(self.genes).st_size
total, scale, unit = ProgressReader.scale_size(input_size)
task = self.progress.add_task("Loading genes", total=total, unit=unit, precision=".1f")
#
self.info("Loading", "gene coordinates from file", repr(self.genes))
with ProgressReader(open(self.genes, "rb"), self.progress, task, scale) as gff_file:
for row in csv.reader(io.TextIOWrapper(gff_file), dialect="excel-tab"):
name, _, _, start, end, _, strand, *_ = row
yield Gene(
SeqRecord(_UnknownSeq(), id=name.rsplit("_", 1)[0]),
int(start),
int(end),
Strand.Coding if strand == "+" else Strand.Reverse,
Protein(name, _UnknownSeq()),
)
# load gene table
self.info("Loading", "genes table from file", repr(self.genes))
with ProgressReader(open(self.genes, "rb"), self.progress, task, scale) as genes_file:
yield from GeneTable.load(io.TextIOWrapper(genes_file)).to_genes()
except OSError as err:
self.error("Fail to parse genes coordinates: {}", err)
raise CommandExit(err.errno) from err
......
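# --- Illustrative sketch, not part of the diff ------------------------------
# `gecco train` now takes a gene table (-g/--genes) instead of a GFF file. An
# existing GFF laid out the way the removed parser expected
# (name, _, _, start, end, _, strand, ...) can be converted with a few lines;
# the file paths below are hypothetical.
import csv
from gecco.model import GeneTable

table = GeneTable()
with open("genes.gff") as gff:
    for row in csv.reader(gff, dialect="excel-tab"):
        name, _, _, start, end, _, strand, *_ = row
        table.sequence_id.append(name.rsplit("_", 1)[0])
        table.protein_id.append(name)
        table.start.append(int(start))
        table.end.append(int(end))
        table.strand.append(strand)
        table.average_p.append(None)
        table.max_p.append(None)

with open("genes.tsv", "w") as out:
    table.dump(out)
# ----------------------------------------------------------------------------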
@@ -24,8 +24,8 @@ from Bio.SeqFeature import SeqFeature, FeatureLocation, CompoundLocation, Refere
from Bio.SeqRecord import SeqRecord
from . import __version__
from ._base import Dumpable
from ._meta import requires, patch_locale
from ._base import Dumpable, Table
from ._meta import patch_locale
__all__ = [
@@ -507,7 +507,7 @@ class _UnknownSeq(Seq):
@dataclass(frozen=True)
class FeatureTable(Dumpable, Sized):
class FeatureTable(Table):
"""A table storing condensed domain annotations from different genes.
"""
@@ -582,120 +582,29 @@ class FeatureTable(Dumpable, Sized):
assert all(x.end == rows[0].end for x in rows)
source = SeqRecord(id=rows[0].sequence_id, seq=_UnknownSeq())
strand = Strand.Coding if rows[0].strand == "+" else Strand.Reverse
protein = Protein(rows[0].protein_id, seq=None)
protein = Protein(rows[0].protein_id, seq=_UnknownSeq())
gene = Gene(source, rows[0].start, rows[0].end, strand, protein)
for row in rows:
domain = Domain(row.domain, row.domain_start, row.domain_end, row.hmm, row.i_evalue, row.pvalue, row.bgc_probability)
gene.protein.domains.append(domain)
yield gene
@requires("pandas")
def to_dataframe(self) -> "pandas.DataFrame":
"""Convert the feature table to a `~pandas.DataFrame`.
Raises:
ImportError: if the `pandas` module could not be imported.
"""
frame = pandas.DataFrame() # type: ignore
for column in self.__annotations__:
frame[column] = getattr(self, column)
return frame
def __iadd__(self, rhs: "FeatureTable") -> "FeatureTable": # noqa: D105
if not isinstance(rhs, FeatureTable):
return NotImplemented
for col in self.__annotations__:
getattr(self, col).extend(getattr(rhs, col))
return self
def __bool__(self) -> bool: # noqa: D105
return len(self) != 0
def __len__(self) -> int: # noqa: D105
return len(self.sequence_id)
def __iter__(self) -> Iterator[Row]: # noqa: D105
columns = { c: operator.attrgetter(c) for c in self.__annotations__ }
for i in range(len(self)):
row = { c: getter(self)[i] for c, getter in columns.items() }
yield self.Row(**row)
@typing.overload
def __getitem__(self, item: slice) -> "FeatureTable": # noqa: D105
pass
@typing.overload
def __getitem__(self, item: int) -> Row: # noqa: D105
pass
def __getitem__(self, item: Union[slice, int]) -> Union["FeatureTable", "Row"]: # noqa: D105
columns = [getattr(self, col)[item] for col in self.__annotations__]
if isinstance(item, slice):
return type(self)(*columns)
else:
return self.Row(*columns)
def dump(self, fh: TextIO, dialect: str = "excel-tab", header: bool = True) -> None:
"""Write the feature table in CSV format to the given file.
Arguments:
fh (file-like `object`): A writable file-handle opened in text mode
to write the feature table to.
dialect (`str`): The CSV dialect to use. See `csv.list_dialects`
for allowed values.
header (`bool`): Whether or not to include the column header when
writing the table (useful for appending to an existing table).
Defaults to `True`.
"""
writer = csv.writer(fh, dialect=dialect)
columns = list(self.__annotations__)
# do not write optional columns if they are completely empty
if all(proba is None for proba in self.bgc_probability):
columns.remove("bgc_probability")
if header:
writer.writerow(columns)
for row in self:
writer.writerow([ getattr(row, col) for col in columns ])

@classmethod
def load(cls, fh: TextIO, dialect: str = "excel-tab") -> "FeatureTable":
"""Load a feature table in CSV format from a file handle in text mode.
"""
table = cls()
reader = csv.reader(fh, dialect=dialect)
header = next(reader)
# get the name of each column
columns = {i:col for i, col in enumerate(header)}
# check that if a column is missing, it is one of the optional values
missing = set(cls.__annotations__).difference(columns.values())
missing_required = missing.difference({"bgc_probability"})
if missing_required:
raise ValueError("table is missing columns: {}".format(", ".join(missing_required)))
# extract elements from the CSV rows
for row in reader:
for col in missing:
getattr(table, col).append(None)
for i,value in enumerate(row):
col = columns.get(i)
if col in ("i_evalue", "pvalue", "bgc_probability"):
getattr(table, col).append(float(value))
elif col in ("start", "end", "domain_start", "domain_end"):
getattr(table, col).append(int(value))
elif col in cls.__annotations__:
getattr(table, col).append(value)
return table

def _format_product_type(value: "ProductType") -> str:
types = value.unpack() or [ProductType.Unknown]
return ";".join(sorted(map(operator.attrgetter("name"), types)))

def _parse_product_type(value: str) -> "ProductType":
types = [ProductType.__members__[x] for x in value.split(";")]
return ProductType.pack(types)
@dataclass(frozen=True)
class ClusterTable(Dumpable, Sized):
class ClusterTable(Table):
"""A table storing condensed information from several clusters.
"""
@@ -703,19 +612,19 @@ class ClusterTable(Dumpable, Sized):
bgc_id: List[str] = field(default_factory = list)
start: List[int] = field(default_factory = lambda: array("l")) # type: ignore
end: List[int] = field(default_factory = lambda: array("l")) # type: ignore
average_p: List[float] = field(default_factory = lambda: array("d")) # type: ignore
max_p: List[float] = field(default_factory = lambda: array("d")) # type: ignore
average_p: List[Optional[float]] = field(default_factory = list) # type: ignore
max_p: List[Optional[float]] = field(default_factory = list) # type: ignore
type: List[ProductType] = field(default_factory = list)
alkaloid_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
polyketide_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
ripp_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
saccharide_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
terpene_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
nrp_probability: List[float] = field(default_factory = lambda: array("d")) # type: ignore
alkaloid_probability: List[Optional[float]] = field(default_factory = list) # type: ignore
polyketide_probability: List[Optional[float]] = field(default_factory = list) # type: ignore
ripp_probability: List[Optional[float]] = field(default_factory = list) # type: ignore
saccharide_probability: List[Optional[float]] = field(default_factory = list) # type: ignore
terpene_probability: List[Optional[float]] = field(default_factory = list) # type: ignore
nrp_probability: List[Optional[float]] = field(default_factory = list) # type: ignore
proteins: List[List[str]] = field(default_factory = list)
domains: List[List[str]] = field(default_factory = list)
proteins: List[Optional[List[str]]] = field(default_factory = list)
domains: List[Optional[List[str]]] = field(default_factory = list)
class Row(NamedTuple):
"""A single row in a cluster table.
@@ -725,17 +634,20 @@ class ClusterTable(Dumpable, Sized):
bgc_id: str
start: int
end: int
average_p: float
max_p: float
average_p: Optional[float]
max_p: Optional[float]
type: ProductType
alkaloid_probability: float
polyketide_probability: float
ripp_probability: float
saccharide_probability: float
terpene_probability: float
nrp_probability: float
proteins: List[str]
domains: List[str]
alkaloid_probability: Optional[float]
polyketide_probability: Optional[float]
ripp_probability: Optional[float]
saccharide_probability: Optional[float]
terpene_probability: Optional[float]
nrp_probability: Optional[float]
proteins: Optional[List[str]]
domains: Optional[List[str]]
_FORMAT_FIELD = {ProductType: _format_product_type, **Table._FORMAT_FIELD}
_PARSE_FIELD = {ProductType: _parse_product_type, **Table._PARSE_FIELD}
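# --- Illustrative sketch, not part of the diff ------------------------------
# The _FORMAT_FIELD / _PARSE_FIELD overrides above are how a Table subclass
# registers (de)serializers for extra column types. A hypothetical subclass
# with a `bool` column would follow the same pattern; FlagTable, _format_bool
# and _parse_bool are invented for illustration, and the base class is assumed
# to be importable as gecco._base.Table.
from dataclasses import dataclass, field
from typing import List, NamedTuple
from gecco._base import Table

def _format_bool(value: bool) -> str:
    return "true" if value else "false"

def _parse_bool(value: str) -> bool:
    return value == "true"

@dataclass(frozen=True)
class FlagTable(Table):
    name: List[str] = field(default_factory=list)
    flag: List[bool] = field(default_factory=list)

    class Row(NamedTuple):
        name: str
        flag: bool

    # extend the inherited registries so the shared dump()/load() handle `bool`
    _FORMAT_FIELD = {bool: _format_bool, **Table._FORMAT_FIELD}
    _PARSE_FIELD = {bool: _parse_bool, **Table._PARSE_FIELD}

    def __len__(self) -> int:
        return len(self.name)
# ----------------------------------------------------------------------------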
@classmethod
def from_clusters(cls, clusters: Iterable[Cluster]) -> "ClusterTable":
@@ -763,123 +675,64 @@ class ClusterTable(Dumpable, Sized):
table.domains.append(sorted(domains))
return table
def __iadd__(self, rhs: "ClusterTable") -> "ClusterTable": # noqa: D105
if not isinstance(rhs, FeatureTable):
return NotImplemented
for col in self.__annotations__:
getattr(self, col).extend(getattr(rhs, col))
return self
def __len__(self) -> int: # noqa: D105
return len(self.sequence_id)
def __iter__(self) -> Iterator[Row]: # noqa: D105
columns = { c: operator.attrgetter(c) for c in self.__annotations__ }
for i in range(len(self)):
row = { c: getter(self)[i] for c, getter in columns.items() }
yield self.Row(**row)
def __bool__(self) -> bool: # noqa: D105
return len(self) != 0
@typing.overload
def __getitem__(self, item: slice) -> "ClusterTable": # noqa: D105
pass

@typing.overload
def __getitem__(self, item: int) -> Row: # noqa: D105
pass

def __getitem__(self, item: Union[int, slice]) -> Union["ClusterTable", Row]: # noqa: D105
columns = [getattr(self, col)[item] for col in self.__annotations__]
if isinstance(item, slice):
return type(self)(*columns)
else:
return self.Row(*columns)

def dump(self, fh: TextIO, dialect: str = "excel-tab", header: bool = True) -> None:
"""Write the cluster table in CSV format to the given file.
Arguments:
fh (file-like `object`): A writable file-handle opened in text mode
to write the cluster table to.
dialect (`str`): The CSV dialect to use. See `csv.list_dialects`
for allowed values.
header (`bool`): Whether or not to include the column header when
writing the table (useful for appending to an existing table).
Defaults to `True`.
"""
writer = csv.writer(fh, dialect=dialect)
columns = list(self.__annotations__)
row = []
if header:
writer.writerow(columns)
for i in range(len(self)):
row.clear()
for col in columns:
value = getattr(self, col)[i]
if col == "type":
types = value.unpack() or [ProductType.Unknown]
value = ";".join(map(operator.attrgetter("name"), types))
elif isinstance(value, list):
value = ";".join(map(str, value))
row.append(value)
writer.writerow(row)

@classmethod
def load(cls, fh: TextIO, dialect: str = "excel-tab") -> "ClusterTable":
"""Load a cluster table in CSV format from a file handle in text mode.
"""
table = cls()
reader = csv.reader(fh, dialect=dialect)
header = next(reader)
# get the name of each column
columns = {i:col for i, col in enumerate(header)}
# check that if a column is missing, it is one of the optional values
missing = set(cls.__annotations__).difference(columns.values())
missing_required = missing.difference({
"average_p",
"max_p",
"type",
"alkaloid_probability",
"polyketide_probability",
"ripp_probability",
"saccharide_probability",
"terpene_probability",
"nrp_probability",
"proteins",
"domains"
})
if missing_required:
raise ValueError("table is missing columns: {}".format(", ".join(missing_required)))
# extract elements from the CSV rows
for row in reader:
for col in missing:
if col in ("proteins", "domains"):
getattr(table, col).append(list())
elif col in ("average_p", "max_p"):
getattr(table, col).append(1.0)
elif col == "type":
table.type.append(ProductType.Unknown)
else:
getattr(table, col).append(0.0)
for i,value in enumerate(row):
col = columns.get(i)
if col in ("i_evalue", "pvalue"):
getattr(table, col).append(float(value))
elif col in ("start", "end", "domain_start", "domain_end"):
getattr(table, col).append(int(value))
elif col == "type":
types = [ProductType.__members__[x] for x in value.split(";")]
table.type.append(ProductType.pack(types))
elif col in cls.__annotations__:
if col.endswith(("_p", "_probability")):
getattr(table, col).append(float(value))
else:
getattr(table, col).append(value)
return table

@dataclass(frozen=True)
class GeneTable(Table):
"""A table storing gene coordinates and optional biosynthetic probabilities.
"""
sequence_id: List[str] = field(default_factory = list)
protein_id: List[str] = field(default_factory = list)
start: List[int] = field(default_factory = lambda: array("l")) # type: ignore
end: List[int] = field(default_factory = lambda: array("l")) # type: ignore
strand: List[str] = field(default_factory = list)
average_p: List[Optional[float]] = field(default_factory = list) # type: ignore
max_p: List[Optional[float]] = field(default_factory = list) # type: ignore

class Row(NamedTuple):
"""A single row in a gene table.
"""
sequence_id: str
protein_id: str
start: int
end: int
strand: str
average_p: Optional[float]
max_p: Optional[float]

@classmethod
def from_genes(cls, genes: Iterable[Gene]) -> "GeneTable":
"""Create a new gene table from an iterable of genes.
"""
table = cls()
for gene in genes:
table.sequence_id.append(gene.source.id)
table.protein_id.append(gene.protein.id)
table.start.append(gene.start)
table.end.append(gene.end)
table.strand.append(gene.strand.sign)
table.average_p.append(gene.average_probability)
table.max_p.append(gene.maximum_probability)
return table

def to_genes(self) -> Iterable[Gene]:
"""Convert a gene table to actual genes.
Since the source sequence cannot be known, a *dummy* sequence is
built for each gene of size ``gene.end``, so that each gene can still
be converted to a `~Bio.SeqRecord.SeqRecord` if needed.
"""
for row in self:
source = SeqRecord(id=row.sequence_id, seq=_UnknownSeq())
strand = Strand.Coding if row.strand == "+" else Strand.Reverse
seq = Seq("X" * ((row.end - row.start) // 3))
protein = Protein(row.protein_id, seq=_UnknownSeq())
yield Gene(source, row.start, row.end, strand, protein, _probability=row.average_p)

def __len__(self) -> int: # noqa: D105
return len(self.protein_id)
@@ -25,7 +25,7 @@ class TestTrain(TestCommand, unittest.TestCase):
def test_train_feature_type_domain(self):
base = os.path.join(self.folder, "data", "mibig-2.0.proG2")
clusters, features, genes = f"{base}.clusters.tsv", f"{base}.features.tsv", f"{base}.gff"
clusters, features, genes = f"{base}.clusters.tsv", f"{base}.features.tsv", f"{base}.genes.tsv"
argv = [
"-vv", "train", "-f", features, "-c", clusters, "-o", self.tmpdir,
@@ -41,7 +41,7 @@ class TestTrain(TestCommand, unittest.TestCase):
def test_train_feature_type_protein(self):
base = os.path.join(self.folder, "data", "mibig-2.0.proG2")
clusters, features, genes = f"{base}.clusters.tsv", f"{base}.features.tsv", f"{base}.gff"
clusters, features, genes = f"{base}.clusters.tsv", f"{base}.features.tsv", f"{base}.genes.tsv"
argv = [
"-vv", "train", "-f", features, "-c", clusters, "-o", self.tmpdir,
......
"""Test `gecco.model.GeneTable` objects.
"""
import itertools
import io
import os
import unittest
import warnings
from unittest import mock
import Bio.SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from gecco.model import Gene, GeneTable, Strand, Protein, Domain, _UnknownSeq
class TestGeneTable(unittest.TestCase):
def test_from_genes(self):
genes = [
Gene(
source=SeqRecord(id="seq1", seq=_UnknownSeq()),
start=1,
end=100,
strand=Strand.Coding,
protein=Protein(
id="seq1_1",
seq=None,
domains=[
Domain(
name="PF00106",
start=1,
end=20,
hmm="Pfam",
i_evalue=0.0,
pvalue=0.0,
probability=0.8
),
]
)
),
Gene(
source=SeqRecord(id="seq2", seq=_UnknownSeq()),
start=3,
end=300,
strand=Strand.Coding,
protein=Protein(
id="seq2_1",
seq=None,
domains=[]
),
_probability=0.6
),
Gene(
source=SeqRecord(id="seq2", seq=_UnknownSeq()),
start=4,
end=49,
strand=Strand.Reverse,
protein=Protein(
id="seq2_2",
seq=None,
domains=[],
)
),
]
table = GeneTable.from_genes(genes)
rows = list(table)
self.assertEqual(len(table), 3)
self.assertEqual(len(rows), 3)
row1 = rows[0]
self.assertEqual(row1.sequence_id, "seq1")
self.assertEqual(row1.protein_id, "seq1_1")
self.assertEqual(row1.start, 1)
self.assertEqual(row1.end, 100)
self.assertEqual(row1.strand, "+")
self.assertEqual(row1.average_p, 0.8)
self.assertEqual(row1.max_p, 0.8)
row2 = rows[1]
self.assertEqual(row2.sequence_id, "seq2")
self.assertEqual(row2.protein_id, "seq2_1")
self.assertEqual(row2.start, 3)
self.assertEqual(row2.end, 300)
self.assertEqual(row2.strand, "+")
self.assertEqual(row2.average_p, 0.6)
self.assertEqual(row2.max_p, 0.6)
row3 = rows[2]
self.assertEqual(row3.sequence_id, "seq2")
self.assertEqual(row3.protein_id, "seq2_2")
self.assertEqual(row3.start, 4)
self.assertEqual(row3.end, 49)
self.assertEqual(row3.strand, "-")
self.assertIs(row3.average_p, None)
self.assertIs(row3.max_p, None)
def test_to_genes(self):
table = GeneTable(
sequence_id=["seq1", "seq2", "seq2"],
protein_id=["seq1_1", "seq2_1", "seq2_2"],
start=[100, 200, 300],
end=[160, 260, 360],
strand=["+", "-", "+"],
average_p=[0.6, 0.2, None],
max_p=[0.8, 0.2, None],
)
genes = list(table.to_genes())
self.assertEqual(genes[0].source.id, "seq1")
self.assertEqual(genes[0].protein.id, "seq1_1")
self.assertEqual(genes[0].start, 100)
self.assertEqual(genes[0].end, 160)
self.assertEqual(genes[0].strand, Strand.Coding)
self.assertEqual(genes[0].average_probability, 0.6)
self.assertEqual(genes[1].source.id, "seq2")
self.assertEqual(genes[1].protein.id, "seq2_1")
self.assertEqual(genes[1].start, 200)
self.assertEqual(genes[1].end, 260)
self.assertEqual(genes[1].strand, Strand.Reverse)
self.assertEqual(genes[1].average_probability, 0.2)
self.assertEqual(genes[2].source.id, "seq2")
self.assertEqual(genes[2].protein.id, "seq2_2")
self.assertEqual(genes[2].start, 300)
self.assertEqual(genes[2].end, 360)
self.assertEqual(genes[2].strand, Strand.Coding)
self.assertIs(genes[2].average_probability, None)
def test_dump(self):
table = GeneTable(
sequence_id=["seq_1", "seq_2", "seq_2"],
protein_id=["seq1_1", "seq2_1", "seq2_2"],
start=[100, 200, 300],
end=[160, 260, 360],
strand=["+", "-", "+"],
average_p=[0.6, 0.2, None],
max_p=[0.8, 0.2, None],
)
buffer = io.StringIO()
table.dump(buffer)
lines = buffer.getvalue().splitlines()
self.assertEqual(
lines[0],
"\t".join(["sequence_id", "protein_id", "start", "end", "strand", "average_p", "max_p"])
)
self.assertEqual(
lines[1],
"\t".join(["seq_1", "seq1_1", "100", "160", "+", "0.6", "0.8"])
)
self.assertEqual(
lines[2],
"\t".join(["seq_2", "seq2_1", "200", "260", "-", "0.2", "0.2"])
)
self.assertEqual(
lines[3],
"\t".join(["seq_2", "seq2_2", "300", "360", "+", "", ""])
)
def test_dump_no_probability(self):
table = GeneTable(
sequence_id=["seq_1", "seq_2"],
protein_id=["seq1_1", "seq2_1"],
start=[100, 200],
end=[160, 260],
strand=["+", "-"],
average_p=[None, None],
max_p=[None, None],
)
buffer = io.StringIO()
table.dump(buffer)
lines = buffer.getvalue().splitlines()
self.assertEqual(
lines[0],
"\t".join(["sequence_id", "protein_id", "start", "end", "strand"])
)
self.assertEqual(
lines[1],
"\t".join(["seq_1", "seq1_1", "100", "160", "+"])
)
self.assertEqual(
lines[2],
"\t".join(["seq_2", "seq2_1", "200", "260", "-"])
)
def test_load(self):
lines = "\n".join([
"\t".join(["sequence_id", "protein_id", "start", "end", "strand", "average_p", "max_p"]),
"\t".join(["seq1", "seq1_1", "100", "160", "+", "0.6", "0.8"]),
"\t".join(["seq2", "seq2_1", "200", "260", "-", "", ""]),
])
table = GeneTable.load(io.StringIO(lines))
self.assertEqual(table[0].sequence_id, "seq1")
self.assertEqual(table[0].protein_id, "seq1_1")
self.assertEqual(table[0].start, 100)
self.assertEqual(table[0].end, 160)
self.assertEqual(table[0].strand, "+")
self.assertEqual(table[0].average_p, 0.6)
self.assertEqual(table[0].max_p, 0.8)
self.assertEqual(table[1].sequence_id, "seq2")
self.assertEqual(table[1].protein_id, "seq2_1")
self.assertEqual(table[1].start, 200)
self.assertEqual(table[1].end, 260)
self.assertEqual(table[1].strand, "-")
self.assertEqual(table[1].average_p, None)
self.assertEqual(table[1].max_p, None)