diff --git a/tests/test_session.py b/tests/test_session.py index 472c38a735eb6d2100d4672758b27d8eaebdfb75..6d998c331b14379c6d9be0c16e0a200ffee7e7a7 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,31 +1,32 @@ +import json +from os.path import exists import pathlib import unittest + from model_server.base.models import DummySemanticSegmentationModel from model_server.base.session import Session +from model_server.base.workflows import WorkflowRunRecord class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: self.sesh = Session() - def test_singleton(self): + def tearDown(self) -> None: + print('Tearing down...') + Session._instances = {} + + def test_session_is_singleton(self): Session._instances = {} self.assertEqual(len(Session._instances), 0) s = Session() self.assertEqual(len(Session._instances), 1) - print(Session._instances) - self.assertTrue(s.logfile.exists(), s.logfile) - - def test_single_session_instance(self): - self.assertIs(self.sesh, Session(), 'Re-initializing Session class returned a new object') + self.assertIs(s, Session()) + self.assertEqual(len(Session._instances), 1) - from os.path import exists + def test_session_logfile_is_valid(self): self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') self.assertTrue(exists(self.sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') - def tearDown(self) -> None: - print('Tearing down...') - Session._instances = {} - def test_changing_session_root_creates_new_directory(self): from model_server.conf.defaults import root from shutil import rmtree @@ -49,31 +50,22 @@ class TestGetSessionObject(unittest.TestCase): self.sesh.set_data_directory('outbound_images', old_paths['inbound_images']) self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images']) - - def test_restart_session(self): + def test_restarting_session_creates_new_logfile(self): logfile1 = self.sesh.logfile + self.assertTrue(logfile1.exists()) self.sesh.restart() logfile2 = self.sesh.logfile + self.assertTrue(logfile2.exists()) self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') - def test_call_session_singleton(self): - logfile1 = self.sesh.logfile - sesh2 = Session() - logfile2 = sesh2.logfile - self.assertEqual(logfile1, logfile2, 'Re-initializing session does not generate new logfile') - def test_log_warning(self): msg = 'A test warning' self.sesh.log_info(msg) - with open(self.sesh.logfile, 'r') as fh: log = fh.read() - self.assertTrue(msg in log) def test_session_records_workflow(self): - import json - from model_server.base.workflows import WorkflowRunRecord di = WorkflowRunRecord( model_id='test_model', input_filepath='/test/input/directory', @@ -86,7 +78,6 @@ class TestGetSessionObject(unittest.TestCase): do = json.load(fh) self.assertEqual(di.dict(), do, 'Manifest record is not correct') - def test_session_loads_model(self): MC = DummySemanticSegmentationModel success = self.sesh.load_model(MC) @@ -107,7 +98,6 @@ class TestGetSessionObject(unittest.TestCase): self.assertIn(MC.__name__ + '_00', self.sesh.models.keys()) self.assertIn(MC.__name__ + '_01', self.sesh.models.keys()) - def test_session_loads_model_with_params(self): MC = DummySemanticSegmentationModel p1 = {'p1': 'abc'} @@ -134,7 +124,6 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(mid, find_mid) def test_change_output_path(self): - import pathlib pa = self.sesh.get_paths()['inbound_images'] self.assertIsInstance(pa, pathlib.Path) self.sesh.set_data_directory('outbound_images', pa.__str__())