From 85c532e45c983d4683bd6fe25c2aba1e5fb6bf26 Mon Sep 17 00:00:00 2001 From: Martin Larralde <martin.larralde@embl.de> Date: Tue, 8 Oct 2024 14:46:33 +0200 Subject: [PATCH] Fix use of deprecated `TopHits` properties in pyhmmer.tests --- pyhmmer/tests/test_hmmer.py | 4 +-- pyhmmer/tests/test_plan7/test_alignment.py | 2 +- pyhmmer/tests/test_plan7/test_pipeline.py | 36 +++++++++++----------- pyhmmer/tests/test_plan7/test_tophits.py | 35 ++++++++++++--------- 4 files changed, 41 insertions(+), 36 deletions(-) diff --git a/pyhmmer/tests/test_hmmer.py b/pyhmmer/tests/test_hmmer.py index b13eba1f..9bf2bf70 100644 --- a/pyhmmer/tests/test_hmmer.py +++ b/pyhmmer/tests/test_hmmer.py @@ -764,7 +764,7 @@ class TestHMMScan(unittest.TestCase): seqs_path, digital=True, alphabet=hmms[0].alphabet ) as seqs_file: for hits in pyhmmer.hmmer.hmmscan(seqs_file, hmms, cpus=1): - expected_lines = expected.get(hits.query_name.decode()) + expected_lines = expected.get(hits.query.name.decode()) if expected_lines is None: self.assertEqual(len(hits), 0) continue @@ -798,7 +798,7 @@ class TestHMMScan(unittest.TestCase): with SequenceFile(seqs_path, digital=True) as seqs_file: with HMMPressedFile(db_path) as pressed_file: for hits in pyhmmer.hmmer.hmmscan(seqs_file, pressed_file, cpus=1): - expected_lines = expected.get(hits.query_name.decode()) + expected_lines = expected.get(hits.query.name.decode()) if expected_lines is None: self.assertEqual(len(hits), 0) continue diff --git a/pyhmmer/tests/test_plan7/test_alignment.py b/pyhmmer/tests/test_plan7/test_alignment.py index 4e86fb9a..adcd672e 100644 --- a/pyhmmer/tests/test_plan7/test_alignment.py +++ b/pyhmmer/tests/test_plan7/test_alignment.py @@ -37,7 +37,7 @@ class TestAlignment(unittest.TestCase): rendered = str(self.ali) lines = rendered.splitlines() self.assertEqual(len(lines), 5) - self.assertTrue(lines[1].strip().startswith(self.hits.query_name.decode())) + self.assertTrue(lines[1].strip().startswith(self.hits.query.name.decode())) self.assertTrue(lines[3].strip().startswith(self.hits[0].name.decode())) @unittest.skipIf(sys.implementation.name == "pypy", "`getsizeof` not supported on PyPY") diff --git a/pyhmmer/tests/test_plan7/test_pipeline.py b/pyhmmer/tests/test_plan7/test_pipeline.py index 594ae9b9..139694bc 100644 --- a/pyhmmer/tests/test_plan7/test_pipeline.py +++ b/pyhmmer/tests/test_plan7/test_pipeline.py @@ -72,8 +72,8 @@ class TestSearchPipeline(unittest.TestCase): pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_hmm(hmm, self.references) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_file(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1") @@ -83,8 +83,8 @@ class TestSearchPipeline(unittest.TestCase): with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file: hits = pipeline.search_hmm(hmm, seqs_file) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_unnamed(self): # make sure `Pipeline.search_hmm` doesn't crash when given an HMM with no name @@ -94,16 +94,16 @@ class TestSearchPipeline(unittest.TestCase): hmm.accession = None pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_hmm(hmm, self.references) - self.assertEqual(hits.query_name, b"test") - self.assertIs(hits.query_accession, None) + self.assertEqual(hits.query.name, b"test") + self.assertIs(hits.query.accession, None) def test_search_seq_block(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1", accession=b"SQ001") pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_seq(seq.digitize(self.alphabet), self.references) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, seq.name) - # self.assertEqual(hits.query_accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... + self.assertEqual(hits.query.name, seq.name) + self.assertEqual(hits.query.accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... def test_search_seq_file(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1", accession=b"SQ001") @@ -111,23 +111,23 @@ class TestSearchPipeline(unittest.TestCase): with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file: hits = pipeline.search_seq(seq.digitize(self.alphabet), seqs_file) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, seq.name) - # self.assertEqual(hits.query_accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... + self.assertEqual(hits.query.name, seq.name) + self.assertEqual(hits.query.accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession... def test_search_msa_block(self): pipeline = Pipeline(alphabet=self.alphabet) hits = pipeline.search_msa(self.msa, self.references) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, self.msa.name) - self.assertEqual(hits.query_accession, self.msa.accession) + self.assertEqual(hits.query.name, self.msa.name) + self.assertEqual(hits.query.accession, self.msa.accession) def test_search_msa_file(self): pipeline = Pipeline(alphabet=self.alphabet) with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file: hits = pipeline.search_msa(self.msa, seqs_file) self.assertEqual(len(hits), 1) - self.assertEqual(hits.query_name, self.msa.name) - self.assertEqual(hits.query_accession, self.msa.accession) + self.assertEqual(hits.query.name, self.msa.name) + self.assertEqual(hits.query.accession, self.msa.accession) def test_Z(self): seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1") @@ -330,8 +330,8 @@ class TestLongTargetsPipeline(unittest.TestCase): pipeline = LongTargetsPipeline(alphabet=dna) hits = pipeline.search_hmm(hmm, targets) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_file(self): dna = Alphabet.dna() @@ -350,8 +350,8 @@ class TestLongTargetsPipeline(unittest.TestCase): pipeline = LongTargetsPipeline(alphabet=dna) hits = pipeline.search_hmm(hmm, targets) - self.assertEqual(hits.query_name, hmm.name) - self.assertEqual(hits.query_accession, hmm.accession) + self.assertEqual(hits.query.name, hmm.name) + self.assertEqual(hits.query.accession, hmm.accession) def test_search_hmm_alphabet_mismatch(self): dna = Alphabet.dna() diff --git a/pyhmmer/tests/test_plan7/test_tophits.py b/pyhmmer/tests/test_plan7/test_tophits.py index c203190b..7a6ecfc4 100644 --- a/pyhmmer/tests/test_plan7/test_tophits.py +++ b/pyhmmer/tests/test_plan7/test_tophits.py @@ -83,9 +83,9 @@ class TestTopHits(unittest.TestCase): self.assertEqual(sum(d.included for d in h1.domains), sum(d.included for d in h2.domains)) def assertHitsEqual(self, hits1, hits2): - self.assertEqual(hits1.query_name, hits2.query_name) - self.assertEqual(hits1.query_accession, hits2.query_accession) - self.assertEqual(hits1.query_length, hits2.query_length) + self.assertEqual(hits1.query.name, hits2.query.name) + self.assertEqual(hits1.query.accession, hits2.query.accession) + self.assertEqual(hits1.query, hits2.query) self.assertEqual(len(hits1), len(hits2)) self.assertEqual(len(hits1.included), len(hits2.included)) self.assertEqual(len(hits1.reported), len(hits2.reported)) @@ -180,9 +180,9 @@ class TestTopHits(unittest.TestCase): self.assertEqual(merged_empty.domZ, 0.0) merged = empty.merge(self.hits) - self.assertEqual(merged.query_name, self.hits.query_name) - self.assertEqual(merged.query_length, self.hits.query_length) - self.assertEqual(merged.query_accession, self.hits.query_accession) + self.assertEqual(merged.query.name, self.hits.query.name) + self.assertEqual(merged.query.M, self.hits.query.M) + self.assertEqual(merged.query.accession, self.hits.query.accession) self.assertEqual(merged.E, self.hits.E) self.assertHitsEqual(merged, self.hits) @@ -206,9 +206,10 @@ class TestTopHits(unittest.TestCase): self.assertEqual(merged.searched_sequences, hits.searched_sequences) self.assertEqual(merged.Z, hits.Z) self.assertEqual(merged.domZ, hits.domZ) - self.assertEqual(merged.query_name, hits.query_name) - self.assertEqual(merged.query_length, hits.query_length) - self.assertEqual(merged.query_accession, hits.query_accession) + self.assertIs(merged.query, hits.query) + self.assertEqual(merged.query.name, hits.query.name) + self.assertEqual(merged.query.M, hits.query.M) + self.assertEqual(merged.query.accession, hits.query.accession) self.assertEqual(merged.E, hits.E) self.assertEqual(merged.domE, hits.domE) @@ -244,9 +245,10 @@ class TestTopHits(unittest.TestCase): self.assertEqual(merged.searched_sequences, hits.searched_sequences) self.assertEqual(merged.Z, hits.Z) self.assertEqual(merged.domZ, hits.domZ) - self.assertEqual(merged.query_name, hits.query_name) - self.assertEqual(merged.query_length, hits.query_length) - self.assertEqual(merged.query_accession, hits.query_accession) + self.assertIs(merged.query, hits.query) + self.assertEqual(merged.query.name, hits.query.name) + self.assertEqual(merged.query.M, hits.query.M) + self.assertEqual(merged.query.accession, hits.query.accession) self.assertEqual(merged.E, hits.E) self.assertEqual(merged.domE, hits.domE) @@ -340,14 +342,17 @@ class TestTopHits(unittest.TestCase): pickled = pickle.loads(pickle.dumps(self.hits)) self.assertHitsEqual(pickled, self.hits) + def test_query(self): + self.assertIs(self.hits.query, self.hmm) + def test_query_name(self): - self.assertEqual(self.hits.query_name, self.hmm.name) + self.assertEqual(self.hits.query.name, self.hmm.name) def test_query_accession(self): - self.assertEqual(self.hits.query_accession, self.hmm.accession) + self.assertEqual(self.hits.query.accession, self.hmm.accession) def test_query_length(self): - self.assertEqual(self.hits.query_length, self.hmm.M) + self.assertEqual(self.hits.query.M, self.hmm.M) def test_write_target(self): buffer = io.BytesIO() -- GitLab