From 85c532e45c983d4683bd6fe25c2aba1e5fb6bf26 Mon Sep 17 00:00:00 2001
From: Martin Larralde <martin.larralde@embl.de>
Date: Tue, 8 Oct 2024 14:46:33 +0200
Subject: [PATCH] Fix use of deprecated `TopHits` properties in pyhmmer.tests

---
 pyhmmer/tests/test_hmmer.py                |  4 +--
 pyhmmer/tests/test_plan7/test_alignment.py |  2 +-
 pyhmmer/tests/test_plan7/test_pipeline.py  | 36 +++++++++++-----------
 pyhmmer/tests/test_plan7/test_tophits.py   | 35 ++++++++++++---------
 4 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/pyhmmer/tests/test_hmmer.py b/pyhmmer/tests/test_hmmer.py
index b13eba1f..9bf2bf70 100644
--- a/pyhmmer/tests/test_hmmer.py
+++ b/pyhmmer/tests/test_hmmer.py
@@ -764,7 +764,7 @@ class TestHMMScan(unittest.TestCase):
             seqs_path, digital=True, alphabet=hmms[0].alphabet
         ) as seqs_file:
             for hits in pyhmmer.hmmer.hmmscan(seqs_file, hmms, cpus=1):
-                expected_lines = expected.get(hits.query_name.decode())
+                expected_lines = expected.get(hits.query.name.decode())
                 if expected_lines is None:
                     self.assertEqual(len(hits), 0)
                     continue
@@ -798,7 +798,7 @@ class TestHMMScan(unittest.TestCase):
         with SequenceFile(seqs_path, digital=True) as seqs_file:
             with HMMPressedFile(db_path) as pressed_file:
                 for hits in pyhmmer.hmmer.hmmscan(seqs_file, pressed_file, cpus=1):
-                    expected_lines = expected.get(hits.query_name.decode())
+                    expected_lines = expected.get(hits.query.name.decode())
                     if expected_lines is None:
                         self.assertEqual(len(hits), 0)
                         continue
diff --git a/pyhmmer/tests/test_plan7/test_alignment.py b/pyhmmer/tests/test_plan7/test_alignment.py
index 4e86fb9a..adcd672e 100644
--- a/pyhmmer/tests/test_plan7/test_alignment.py
+++ b/pyhmmer/tests/test_plan7/test_alignment.py
@@ -37,7 +37,7 @@ class TestAlignment(unittest.TestCase):
         rendered = str(self.ali)
         lines = rendered.splitlines()
         self.assertEqual(len(lines), 5)
-        self.assertTrue(lines[1].strip().startswith(self.hits.query_name.decode()))
+        self.assertTrue(lines[1].strip().startswith(self.hits.query.name.decode()))
         self.assertTrue(lines[3].strip().startswith(self.hits[0].name.decode()))
 
     @unittest.skipIf(sys.implementation.name == "pypy", "`getsizeof` not supported on PyPY")
diff --git a/pyhmmer/tests/test_plan7/test_pipeline.py b/pyhmmer/tests/test_plan7/test_pipeline.py
index 594ae9b9..139694bc 100644
--- a/pyhmmer/tests/test_plan7/test_pipeline.py
+++ b/pyhmmer/tests/test_plan7/test_pipeline.py
@@ -72,8 +72,8 @@ class TestSearchPipeline(unittest.TestCase):
         pipeline = Pipeline(alphabet=self.alphabet)
         hits = pipeline.search_hmm(hmm, self.references)
         self.assertEqual(len(hits), 1)
-        self.assertEqual(hits.query_name, hmm.name)
-        self.assertEqual(hits.query_accession, hmm.accession)
+        self.assertEqual(hits.query.name, hmm.name)
+        self.assertEqual(hits.query.accession, hmm.accession)
 
     def test_search_hmm_file(self):
         seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1")
@@ -83,8 +83,8 @@ class TestSearchPipeline(unittest.TestCase):
         with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file:
             hits = pipeline.search_hmm(hmm, seqs_file)
         self.assertEqual(len(hits), 1)
-        self.assertEqual(hits.query_name, hmm.name)
-        self.assertEqual(hits.query_accession, hmm.accession)
+        self.assertEqual(hits.query.name, hmm.name)
+        self.assertEqual(hits.query.accession, hmm.accession)
 
     def test_search_hmm_unnamed(self):
         # make sure `Pipeline.search_hmm` doesn't crash when given an HMM with no name
@@ -94,16 +94,16 @@ class TestSearchPipeline(unittest.TestCase):
         hmm.accession = None
         pipeline = Pipeline(alphabet=self.alphabet)
         hits = pipeline.search_hmm(hmm, self.references)
-        self.assertEqual(hits.query_name, b"test")
-        self.assertIs(hits.query_accession, None)
+        self.assertEqual(hits.query.name, b"test")
+        self.assertIs(hits.query.accession, None)
 
     def test_search_seq_block(self):
         seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1", accession=b"SQ001")
         pipeline = Pipeline(alphabet=self.alphabet)
         hits = pipeline.search_seq(seq.digitize(self.alphabet), self.references)
         self.assertEqual(len(hits), 1)
-        self.assertEqual(hits.query_name, seq.name)
-        # self.assertEqual(hits.query_accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession...
+        self.assertEqual(hits.query.name, seq.name)
+        self.assertEqual(hits.query.accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession...
 
     def test_search_seq_file(self):
         seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1", accession=b"SQ001")
@@ -111,23 +111,23 @@ class TestSearchPipeline(unittest.TestCase):
         with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file:
             hits = pipeline.search_seq(seq.digitize(self.alphabet), seqs_file)
         self.assertEqual(len(hits), 1)
-        self.assertEqual(hits.query_name, seq.name)
-        # self.assertEqual(hits.query_accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession...
+        self.assertEqual(hits.query.name, seq.name)
+        self.assertEqual(hits.query.accession, seq.accession) # NOTE: p7_SingleBuilder doesn't copy the accession...
 
     def test_search_msa_block(self):
         pipeline = Pipeline(alphabet=self.alphabet)
         hits = pipeline.search_msa(self.msa, self.references)
         self.assertEqual(len(hits), 1)
-        self.assertEqual(hits.query_name, self.msa.name)
-        self.assertEqual(hits.query_accession, self.msa.accession)
+        self.assertEqual(hits.query.name, self.msa.name)
+        self.assertEqual(hits.query.accession, self.msa.accession)
 
     def test_search_msa_file(self):
         pipeline = Pipeline(alphabet=self.alphabet)
         with SequenceFile(self.reference_path, digital=True, alphabet=self.alphabet) as seqs_file:
             hits = pipeline.search_msa(self.msa, seqs_file)
         self.assertEqual(len(hits), 1)
-        self.assertEqual(hits.query_name, self.msa.name)
-        self.assertEqual(hits.query_accession, self.msa.accession)
+        self.assertEqual(hits.query.name, self.msa.name)
+        self.assertEqual(hits.query.accession, self.msa.accession)
 
     def test_Z(self):
         seq = TextSequence(sequence="IRGIYNIIKSVAEDIEIGIIPPSKDHVTISSFKSPRIADT", name=b"seq1")
@@ -330,8 +330,8 @@ class TestLongTargetsPipeline(unittest.TestCase):
 
         pipeline = LongTargetsPipeline(alphabet=dna)
         hits = pipeline.search_hmm(hmm, targets)
-        self.assertEqual(hits.query_name, hmm.name)
-        self.assertEqual(hits.query_accession, hmm.accession)
+        self.assertEqual(hits.query.name, hmm.name)
+        self.assertEqual(hits.query.accession, hmm.accession)
 
     def test_search_hmm_file(self):
         dna = Alphabet.dna()
@@ -350,8 +350,8 @@ class TestLongTargetsPipeline(unittest.TestCase):
                 pipeline = LongTargetsPipeline(alphabet=dna)
                 hits = pipeline.search_hmm(hmm, targets)
 
-        self.assertEqual(hits.query_name, hmm.name)
-        self.assertEqual(hits.query_accession, hmm.accession)
+        self.assertEqual(hits.query.name, hmm.name)
+        self.assertEqual(hits.query.accession, hmm.accession)
 
     def test_search_hmm_alphabet_mismatch(self):
         dna = Alphabet.dna()
diff --git a/pyhmmer/tests/test_plan7/test_tophits.py b/pyhmmer/tests/test_plan7/test_tophits.py
index c203190b..7a6ecfc4 100644
--- a/pyhmmer/tests/test_plan7/test_tophits.py
+++ b/pyhmmer/tests/test_plan7/test_tophits.py
@@ -83,9 +83,9 @@ class TestTopHits(unittest.TestCase):
         self.assertEqual(sum(d.included for d in h1.domains), sum(d.included for d in h2.domains))
 
     def assertHitsEqual(self, hits1, hits2):
-        self.assertEqual(hits1.query_name, hits2.query_name)
-        self.assertEqual(hits1.query_accession, hits2.query_accession)
-        self.assertEqual(hits1.query_length, hits2.query_length)
+        self.assertEqual(hits1.query.name, hits2.query.name)
+        self.assertEqual(hits1.query.accession, hits2.query.accession)
+        self.assertEqual(hits1.query, hits2.query)
         self.assertEqual(len(hits1), len(hits2))
         self.assertEqual(len(hits1.included), len(hits2.included))
         self.assertEqual(len(hits1.reported), len(hits2.reported))
@@ -180,9 +180,9 @@ class TestTopHits(unittest.TestCase):
         self.assertEqual(merged_empty.domZ, 0.0)
 
         merged = empty.merge(self.hits)
-        self.assertEqual(merged.query_name, self.hits.query_name)
-        self.assertEqual(merged.query_length, self.hits.query_length)
-        self.assertEqual(merged.query_accession, self.hits.query_accession)
+        self.assertEqual(merged.query.name, self.hits.query.name)
+        self.assertEqual(merged.query.M, self.hits.query.M)
+        self.assertEqual(merged.query.accession, self.hits.query.accession)
         self.assertEqual(merged.E, self.hits.E)
         self.assertHitsEqual(merged, self.hits)
 
@@ -206,9 +206,10 @@ class TestTopHits(unittest.TestCase):
         self.assertEqual(merged.searched_sequences, hits.searched_sequences)
         self.assertEqual(merged.Z, hits.Z)
         self.assertEqual(merged.domZ, hits.domZ)
-        self.assertEqual(merged.query_name, hits.query_name)
-        self.assertEqual(merged.query_length, hits.query_length)
-        self.assertEqual(merged.query_accession, hits.query_accession)
+        self.assertIs(merged.query, hits.query)
+        self.assertEqual(merged.query.name, hits.query.name)
+        self.assertEqual(merged.query.M, hits.query.M)
+        self.assertEqual(merged.query.accession, hits.query.accession)
         self.assertEqual(merged.E, hits.E)
         self.assertEqual(merged.domE, hits.domE)
 
@@ -244,9 +245,10 @@ class TestTopHits(unittest.TestCase):
         self.assertEqual(merged.searched_sequences, hits.searched_sequences)
         self.assertEqual(merged.Z, hits.Z)
         self.assertEqual(merged.domZ, hits.domZ)
-        self.assertEqual(merged.query_name, hits.query_name)
-        self.assertEqual(merged.query_length, hits.query_length)
-        self.assertEqual(merged.query_accession, hits.query_accession)
+        self.assertIs(merged.query, hits.query)
+        self.assertEqual(merged.query.name, hits.query.name)
+        self.assertEqual(merged.query.M, hits.query.M)
+        self.assertEqual(merged.query.accession, hits.query.accession)
         self.assertEqual(merged.E, hits.E)
         self.assertEqual(merged.domE, hits.domE)
 
@@ -340,14 +342,17 @@ class TestTopHits(unittest.TestCase):
         pickled = pickle.loads(pickle.dumps(self.hits))
         self.assertHitsEqual(pickled, self.hits)
 
+    def test_query(self):
+        self.assertIs(self.hits.query, self.hmm)
+
     def test_query_name(self):
-        self.assertEqual(self.hits.query_name, self.hmm.name)
+        self.assertEqual(self.hits.query.name, self.hmm.name)
 
     def test_query_accession(self):
-        self.assertEqual(self.hits.query_accession, self.hmm.accession)
+        self.assertEqual(self.hits.query.accession, self.hmm.accession)
 
     def test_query_length(self):
-        self.assertEqual(self.hits.query_length, self.hmm.M)
+        self.assertEqual(self.hits.query.M, self.hmm.M)
 
     def test_write_target(self):
         buffer = io.BytesIO()
-- 
GitLab