Skip to content
Snippets Groups Projects
Commit af126fbb authored by Martin Larralde
Browse files

Require a query when creating a `TopHits` object

parent e444ee4b
No related branches found
No related tags found
No related merge requests found
......@@ -5763,7 +5763,7 @@ cdef class Pipeline:
cdef P7_OPROFILE* om
cdef int status
cdef int allocM
cdef TopHits hits = TopHits()
cdef TopHits hits = TopHits(query)
# check that the sequence file is in digital mode
if SearchTargets is SequenceFile:
......@@ -6141,7 +6141,7 @@ cdef class Pipeline:
"""
cdef int allocM
cdef Profile profile
cdef TopHits hits = TopHits()
cdef TopHits hits = TopHits(query)
assert self._pli != NULL
......@@ -6840,7 +6840,7 @@ cdef class LongTargetsPipeline(Pipeline):
cdef HMM hmm
cdef int max_length
cdef ScoreData scoredata = ScoreData.__new__(ScoreData)
cdef TopHits hits = TopHits()
cdef TopHits hits = TopHits(query)
cdef P7_HIT* hit = NULL
cdef P7_OPROFILE* om = NULL
......@@ -7680,12 +7680,13 @@ cdef class TopHits:
self._query = None
memset(&self._pli, 0, sizeof(P7_PIPELINE))
def __init__(self):
"""__init__(self)\n--\n
def __init__(self, object query not None):
"""__init__(self, query)\n--\n
Create an empty `TopHits` instance.
"""
self._query = query
with nogil:
# free allocated memory (in case __init__ is called more than once)
libhmmer.p7_tophits.p7_tophits_Destroy(self._th)
......@@ -7727,7 +7728,7 @@ cdef class TopHits:
return self.merge(other)
def __reduce__(self):
return TopHits, (), self.__getstate__()
return TopHits, (self.query,), self.__getstate__()
def __getstate__(self):
assert self._th != NULL
......@@ -7749,7 +7750,6 @@ cdef class TopHits:
hits.append(offset)
return {
"query": self._query,
"unsrt": unsrt,
"hit": hits,
"Nalloc": self._th.Nalloc,
......@@ -7814,9 +7814,6 @@ cdef class TopHits:
cdef size_t offset
cdef VectorU8 hit_state
# record query name and accession
self._query = state["query"]
# deallocate current data if needed
if self._th != NULL:
libhmmer.p7_tophits.p7_tophits_Destroy(self._th)
......@@ -7923,6 +7920,9 @@ cdef class TopHits:
.. versionadded:: 0.6.1
.. deprecated:: 0.10.10
Use ``TopHits.query`` to access the original query directly.
"""
if self._query is None:
return None
......@@ -7934,6 +7934,9 @@ cdef class TopHits:
.. versionadded:: 0.6.1
.. deprecated:: 0.10.10
Use ``TopHits.query`` to access the original query directly.
"""
if self._query is None:
return None
......@@ -7945,11 +7948,20 @@ cdef class TopHits:
.. versionadded:: 0.10.5
.. deprecated:: 0.10.10
Use ``TopHits.query`` to access the original query directly.
"""
if self._query is None:
return 0
return self._query.M if isinstance(self._query, HMM) else len(self._query)
@property
def query(self):
"""`object`: The query object these hits were obtained for.
"""
return self._query
@property
def Z(self):
"""`float`: The effective number of targets searched.
......@@ -8189,7 +8201,7 @@ cdef class TopHits:
raise ValueError("Trying to merge `TopHits` obtained from pipelines manually configured to different `domZ` values.")
# check threshold modes are consistent
if self._pli.by_E != other.by_E:
raise ValueError("Trying to merge `TopHits` obtained from pipelines with different reporting threshold modes")
raise ValueError(f"Trying to merge `TopHits` obtained from pipelines with different reporting threshold modes: {self._pli.by_E} != {other.by_E}")
elif self._pli.dom_by_E != other.dom_by_E:
raise ValueError("Trying to merge `TopHits` obtained from pipelines with different domain reporting threshold modes")
elif self._pli.inc_by_E != other.inc_by_E:
......@@ -8544,17 +8556,17 @@ cdef class TopHits:
# not referenced anywhere else)
other_copy = other.copy()
# check that names/accessions are consistent
if merged._query != other._query:
raise ValueError("Trying to merge `TopHits` obtained from different queries")
# just store the copy if merging inside an empty uninitialized `TopHits`
if merged._th.N == 0 and merged._query is None:
if merged._th.N == 0:
merged._query = other._query
memcpy(&merged._pli, &other_copy._pli, sizeof(P7_PIPELINE))
merged._th, other_copy._th = other_copy._th, merged._th
continue
# check that names/accessions are consistent
if merged._query != other._query:
raise ValueError("Trying to merge `TopHits` obtained from different queries")
# check that the parameters are the same
merged._check_threshold_parameters(&other._pli)
......
......@@ -45,22 +45,6 @@ def load_tests(loader, tests, ignore):
_current_cwd = os.getcwd()
_daemon_client = mock.patch("pyhmmer.daemon.Client")
def setUp(self):
warnings.simplefilter("ignore")
os.chdir(os.path.realpath(os.path.join(__file__, "..", "..")))
# mock the HMMPGMD client to show usage examples without having
# to actually spawn an HMMPGMD server in the background
_client = _daemon_client.__enter__()
_client.return_value = _client
_client.__enter__.return_value = _client
_client.connect.return_value = None
_client.search_hmm.return_value = pyhmmer.plan7.TopHits()
def tearDown(self):
os.chdir(_current_cwd)
warnings.simplefilter(warnings.defaultaction)
_daemon_client.__exit__(None, None, None)
# doctests are not compatible with `green`, so we may want to bail out
# early if `green` is running the tests
if sys.argv[0].endswith("green"):
......@@ -86,6 +70,22 @@ def load_tests(loader, tests, ignore):
with pyhmmer.easel.SequenceFile(seq_path, digital=True) as seq_file:
reductase = next(seq for seq in seq_file if b"P12748" in seq.name)
def setUp(self):
warnings.simplefilter("ignore")
os.chdir(os.path.realpath(os.path.join(__file__, "..", "..")))
# mock the HMMPGMD client to show usage examples without having
# to actually spawn an HMMPGMD server in the background
_client = _daemon_client.__enter__()
_client.return_value = _client
_client.__enter__.return_value = _client
_client.connect.return_value = None
_client.search_hmm.return_value = pyhmmer.plan7.TopHits(thioesterase)
def tearDown(self):
os.chdir(_current_cwd)
warnings.simplefilter(warnings.defaultaction)
_daemon_client.__exit__(None, None, None)
# recursively traverse all library submodules and load tests from them
packages = [None, pyhmmer]
......
......@@ -114,7 +114,7 @@ class TestTopHits(unittest.TestCase):
self.assertEqual(search_hits.mode, "search")
def test_bool(self):
self.assertFalse(pyhmmer.plan7.TopHits())
self.assertFalse(pyhmmer.plan7.TopHits(self.hmm))
self.assertTrue(self.hits)
def test_index_error(self):
......@@ -131,36 +131,36 @@ class TestTopHits(unittest.TestCase):
self.assertEqual(dom.name, dom_last.name)
def test_Z(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertEqual(empty.Z, 0)
self.assertEqual(self.hits.Z, len(self.seqs))
def test_strand(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertIs(empty.strand, None)
def test_searched_sequences(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertEqual(empty.searched_sequences, 0)
self.assertEqual(self.hits.searched_sequences, len(self.seqs))
def test_searched_nodes(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertEqual(empty.searched_nodes, 0)
self.assertEqual(self.hits.searched_nodes, self.hmm.M)
def test_searched_residues(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertEqual(empty.searched_residues, 0)
self.assertEqual(self.hits.searched_residues, sum(map(len, self.seqs)))
def test_searched_models(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertEqual(empty.searched_sequences, 0)
self.assertEqual(self.hits.searched_models, 1)
def test_merge_empty(self):
empty = TopHits()
empty = TopHits(self.hmm)
self.assertFalse(empty.long_targets)
self.assertEqual(empty.Z, 0.0)
self.assertEqual(empty.domZ, 0.0)
......@@ -170,7 +170,7 @@ class TestTopHits(unittest.TestCase):
self.assertEqual(empty2.Z, 0.0)
self.assertEqual(empty2.domZ, 0.0)
merged_empty = empty.merge(TopHits())
merged_empty = empty.merge(TopHits(self.hmm))
self.assertHitsEqual(merged_empty, empty)
self.assertEqual(merged_empty.searched_residues, 0)
self.assertEqual(merged_empty.searched_sequences, 0)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment