From 5518c337031ae92702a3304e987d92b4e7af1df9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 22 Mar 2024 14:33:29 +0100
Subject: [PATCH] Cleaned up tests

---
 tests/test_session.py | 41 +++++++++++++++--------------------------
 1 file changed, 15 insertions(+), 26 deletions(-)

diff --git a/tests/test_session.py b/tests/test_session.py
index 472c38a7..6d998c33 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,31 +1,32 @@
+import json
+from os.path import exists
 import pathlib
 import unittest
+
 from model_server.base.models import DummySemanticSegmentationModel
 from model_server.base.session import Session
+from model_server.base.workflows import WorkflowRunRecord
 
 class TestGetSessionObject(unittest.TestCase):
     def setUp(self) -> None:
         self.sesh = Session()
 
-    def test_singleton(self):
+    def tearDown(self) -> None:
+        print('Tearing down...')
+        Session._instances = {}
+
+    def test_session_is_singleton(self):
         Session._instances = {}
         self.assertEqual(len(Session._instances), 0)
         s = Session()
         self.assertEqual(len(Session._instances), 1)
-        print(Session._instances)
-        self.assertTrue(s.logfile.exists(), s.logfile)
-
-    def test_single_session_instance(self):
-        self.assertIs(self.sesh, Session(), 'Re-initializing Session class returned a new object')
+        self.assertIs(s, Session())
+        self.assertEqual(len(Session._instances), 1)
 
-        from os.path import exists
+    def test_session_logfile_is_valid(self):
         self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
         self.assertTrue(exists(self.sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
 
-    def tearDown(self) -> None:
-        print('Tearing down...')
-        Session._instances = {}
-
     def test_changing_session_root_creates_new_directory(self):
         from model_server.conf.defaults import root
         from shutil import rmtree
@@ -49,31 +50,22 @@ class TestGetSessionObject(unittest.TestCase):
         self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
         self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
 
-
-    def test_restart_session(self):
+    def test_restarting_session_creates_new_logfile(self):
         logfile1 = self.sesh.logfile
+        self.assertTrue(logfile1.exists())
         self.sesh.restart()
         logfile2 = self.sesh.logfile
+        self.assertTrue(logfile2.exists())
         self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
 
-    def test_call_session_singleton(self):
-        logfile1 = self.sesh.logfile
-        sesh2 = Session()
-        logfile2 = sesh2.logfile
-        self.assertEqual(logfile1, logfile2, 'Re-initializing session does not generate new logfile')
-
     def test_log_warning(self):
         msg = 'A test warning'
         self.sesh.log_info(msg)
-
         with open(self.sesh.logfile, 'r') as fh:
             log = fh.read()
-
         self.assertTrue(msg in log)
 
     def test_session_records_workflow(self):
-        import json
-        from model_server.base.workflows import WorkflowRunRecord
         di = WorkflowRunRecord(
             model_id='test_model',
             input_filepath='/test/input/directory',
@@ -86,7 +78,6 @@ class TestGetSessionObject(unittest.TestCase):
             do = json.load(fh)
         self.assertEqual(di.dict(), do, 'Manifest record is not correct')
 
-
     def test_session_loads_model(self):
         MC = DummySemanticSegmentationModel
         success = self.sesh.load_model(MC)
@@ -107,7 +98,6 @@ class TestGetSessionObject(unittest.TestCase):
         self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
         self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
 
-
     def test_session_loads_model_with_params(self):
         MC = DummySemanticSegmentationModel
         p1 = {'p1': 'abc'}
@@ -134,7 +124,6 @@ class TestGetSessionObject(unittest.TestCase):
         self.assertEqual(mid, find_mid)
 
     def test_change_output_path(self):
-        import pathlib
         pa = self.sesh.get_paths()['inbound_images']
         self.assertIsInstance(pa, pathlib.Path)
         self.sesh.set_data_directory('outbound_images', pa.__str__())
-- 
GitLab