diff --git a/model_server/base/session.py b/model_server/base/session.py index ca119b546516259db2daa0c041d2ffbddddb8856..a72dfa5a9568d1fb7f4b1d0eac313f3c77a2bf72 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -13,18 +13,19 @@ from model_server.base.workflows import WorkflowRunRecord def create_manifest_json(): pass -logger = logging.getLogger(__name__) +class Singleton(type): + _instances = {} -class Session(object): + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + +class Session(object, metaclass=Singleton): """ Singleton class for a server session that persists data between API calls """ - def __new__(cls): - if not hasattr(cls, 'instance'): - cls.instance = super(Session, cls).__new__(cls) - return cls.instance - def __init__(self, root: str = None): print('Initializing session') self.models = {} # model_id : model object @@ -32,7 +33,8 @@ class Session(object): self.paths = self.make_paths(root) self.logfile = self.paths['logs'] / f'session.log' - logging.basicConfig(filename=self.logfile, level=logging.INFO) + logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True) + self.logger = logging.getLogger(__name__) self.log_info('Initialized session') @@ -83,7 +85,7 @@ class Session(object): return f'{yyyymmdd}-{idx:04d}' def log_info(self, msg): - logger.info(msg) + self.logger.info(msg) # def log_event(self, event: str): # """ diff --git a/tests/test_session.py b/tests/test_session.py index 92f1232100c70b67c43f69e63712f140ce5bbbfb..472c38a735eb6d2100d4672758b27d8eaebdfb75 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -5,52 +5,68 @@ from model_server.base.session import Session class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - pass + self.sesh = Session() + + def test_singleton(self): + Session._instances = {} + self.assertEqual(len(Session._instances), 0) + s = Session() + self.assertEqual(len(Session._instances), 1) + print(Session._instances) + self.assertTrue(s.logfile.exists(), s.logfile) def test_single_session_instance(self): - sesh = Session() - self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object') + self.assertIs(self.sesh, Session(), 'Re-initializing Session class returned a new object') from os.path import exists - self.assertTrue(exists(sesh.logfile), 'Session did not create a log file in the correct place') - self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') + self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place') + self.assertTrue(exists(self.sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') + + def tearDown(self) -> None: + print('Tearing down...') + Session._instances = {} def test_changing_session_root_creates_new_directory(self): from model_server.conf.defaults import root from shutil import rmtree - sesh = Session() - old_paths = sesh.get_paths() + old_paths = self.sesh.get_paths() newroot = root / 'subdir' - sesh.restart(root=newroot) - new_paths = sesh.get_paths() + self.sesh.restart(root=newroot) + new_paths = self.sesh.get_paths() for k in old_paths.keys(): self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) + + # this is necessary because logger itself is a singleton class + self.tearDown() + self.setUp() rmtree(newroot) self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory') def test_change_session_subdirectory(self): - sesh = Session() - old_paths = sesh.get_paths() + old_paths = self.sesh.get_paths() print(old_paths) - sesh.set_data_directory('outbound_images', old_paths['inbound_images']) - self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images']) + self.sesh.set_data_directory('outbound_images', old_paths['inbound_images']) + self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images']) def test_restart_session(self): - sesh = Session() - logfile1 = sesh.logfile - sesh.restart() - logfile2 = sesh.logfile + logfile1 = self.sesh.logfile + self.sesh.restart() + logfile2 = self.sesh.logfile self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') + def test_call_session_singleton(self): + logfile1 = self.sesh.logfile + sesh2 = Session() + logfile2 = sesh2.logfile + self.assertEqual(logfile1, logfile2, 'Re-initializing session does not generate new logfile') + def test_log_warning(self): - # this is passing on its own but failing in conjunction with other tests - sesh = Session() msg = 'A test warning' - sesh.log_info(msg) + self.sesh.log_info(msg) - with open(sesh.logfile, 'r') as fh: + with open(self.sesh.logfile, 'r') as fh: log = fh.read() self.assertTrue(msg in log) @@ -58,7 +74,6 @@ class TestGetSessionObject(unittest.TestCase): def test_session_records_workflow(self): import json from model_server.base.workflows import WorkflowRunRecord - sesh = Session() di = WorkflowRunRecord( model_id='test_model', input_filepath='/test/input/directory', @@ -66,18 +81,17 @@ class TestGetSessionObject(unittest.TestCase): success=True, timer_results={'start': 0.123}, ) - sesh.record_workflow_run(di) - with open(sesh.manifest_json, 'r') as fh: + self.sesh.record_workflow_run(di) + with open(self.sesh.manifest_json, 'r') as fh: do = json.load(fh) self.assertEqual(di.dict(), do, 'Manifest record is not correct') def test_session_loads_model(self): - sesh = Session() MC = DummySemanticSegmentationModel - success = sesh.load_model(MC) + success = self.sesh.load_model(MC) self.assertTrue(success) - loaded_models = sesh.describe_loaded_models() + loaded_models = self.sesh.describe_loaded_models() self.assertTrue( (MC.__name__ + '_00') in loaded_models.keys() ) @@ -87,46 +101,42 @@ class TestGetSessionObject(unittest.TestCase): ) def test_session_loads_second_instance_of_same_model(self): - sesh = Session() MC = DummySemanticSegmentationModel - sesh.load_model(MC) - sesh.load_model(MC) - self.assertIn(MC.__name__ + '_00', sesh.models.keys()) - self.assertIn(MC.__name__ + '_01', sesh.models.keys()) + self.sesh.load_model(MC) + self.sesh.load_model(MC) + self.assertIn(MC.__name__ + '_00', self.sesh.models.keys()) + self.assertIn(MC.__name__ + '_01', self.sesh.models.keys()) def test_session_loads_model_with_params(self): - sesh = Session() MC = DummySemanticSegmentationModel p1 = {'p1': 'abc'} - success = sesh.load_model(MC, params=p1) + success = self.sesh.load_model(MC, params=p1) self.assertTrue(success) - loaded_models = sesh.describe_loaded_models() + loaded_models = self.sesh.describe_loaded_models() mid = MC.__name__ + '_00' self.assertEqual(loaded_models[mid]['params'], p1) # load a second model and confirm that the first is locatable by its param entry p2 = {'p2': 'def'} - sesh.load_model(MC, params=p2) - find_mid = sesh.find_param_in_loaded_models('p1', 'abc') + self.sesh.load_model(MC, params=p2) + find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc') self.assertEqual(mid, find_mid) - self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1) + self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1) def test_session_finds_existing_model_with_different_path_formats(self): - sesh = Session() MC = DummySemanticSegmentationModel p1 = {'path': 'c:\\windows\\dummy.pa'} p2 = {'path': 'c:/windows/dummy.pa'} - mid = sesh.load_model(MC, params=p1) + mid = self.sesh.load_model(MC, params=p1) assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) - find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) + find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) self.assertEqual(mid, find_mid) def test_change_output_path(self): import pathlib - sesh = Session() - pa = sesh.get_paths()['inbound_images'] + pa = self.sesh.get_paths()['inbound_images'] self.assertIsInstance(pa, pathlib.Path) - sesh.set_data_directory('outbound_images', pa.__str__()) - self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images']) - self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path) \ No newline at end of file + self.sesh.set_data_directory('outbound_images', pa.__str__()) + self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images']) + self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path) \ No newline at end of file