Skip to content
Snippets Groups Projects
Commit 43f84ad3 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Cleared up issues with Singleton behavior, largely confused by the fact this...

Cleared up issues with Singleton behavior, largely confused by the fact this python logger itself is a Singleton that persists between tests
parent 0f1003c4
No related branches found
No related tags found
No related merge requests found
......@@ -13,18 +13,19 @@ from model_server.base.workflows import WorkflowRunRecord
def create_manifest_json():
pass
logger = logging.getLogger(__name__)
class Singleton(type):
_instances = {}
class Session(object):
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Session(object, metaclass=Singleton):
"""
Singleton class for a server session that persists data between API calls
"""
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(Session, cls).__new__(cls)
return cls.instance
def __init__(self, root: str = None):
print('Initializing session')
self.models = {} # model_id : model object
......@@ -32,7 +33,8 @@ class Session(object):
self.paths = self.make_paths(root)
self.logfile = self.paths['logs'] / f'session.log'
logging.basicConfig(filename=self.logfile, level=logging.INFO)
logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True)
self.logger = logging.getLogger(__name__)
self.log_info('Initialized session')
......@@ -83,7 +85,7 @@ class Session(object):
return f'{yyyymmdd}-{idx:04d}'
def log_info(self, msg):
logger.info(msg)
self.logger.info(msg)
# def log_event(self, event: str):
# """
......
......@@ -5,52 +5,68 @@ from model_server.base.session import Session
class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None:
pass
self.sesh = Session()
def test_singleton(self):
Session._instances = {}
self.assertEqual(len(Session._instances), 0)
s = Session()
self.assertEqual(len(Session._instances), 1)
print(Session._instances)
self.assertTrue(s.logfile.exists(), s.logfile)
def test_single_session_instance(self):
sesh = Session()
self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
self.assertIs(self.sesh, Session(), 'Re-initializing Session class returned a new object')
from os.path import exists
self.assertTrue(exists(sesh.logfile), 'Session did not create a log file in the correct place')
self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
self.assertTrue(exists(self.sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
def tearDown(self) -> None:
print('Tearing down...')
Session._instances = {}
def test_changing_session_root_creates_new_directory(self):
from model_server.conf.defaults import root
from shutil import rmtree
sesh = Session()
old_paths = sesh.get_paths()
old_paths = self.sesh.get_paths()
newroot = root / 'subdir'
sesh.restart(root=newroot)
new_paths = sesh.get_paths()
self.sesh.restart(root=newroot)
new_paths = self.sesh.get_paths()
for k in old_paths.keys():
self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
# this is necessary because logger itself is a singleton class
self.tearDown()
self.setUp()
rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self):
sesh = Session()
old_paths = sesh.get_paths()
old_paths = self.sesh.get_paths()
print(old_paths)
sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])
self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
def test_restart_session(self):
sesh = Session()
logfile1 = sesh.logfile
sesh.restart()
logfile2 = sesh.logfile
logfile1 = self.sesh.logfile
self.sesh.restart()
logfile2 = self.sesh.logfile
self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_call_session_singleton(self):
logfile1 = self.sesh.logfile
sesh2 = Session()
logfile2 = sesh2.logfile
self.assertEqual(logfile1, logfile2, 'Re-initializing session does not generate new logfile')
def test_log_warning(self):
# this is passing on its own but failing in conjunction with other tests
sesh = Session()
msg = 'A test warning'
sesh.log_info(msg)
self.sesh.log_info(msg)
with open(sesh.logfile, 'r') as fh:
with open(self.sesh.logfile, 'r') as fh:
log = fh.read()
self.assertTrue(msg in log)
......@@ -58,7 +74,6 @@ class TestGetSessionObject(unittest.TestCase):
def test_session_records_workflow(self):
import json
from model_server.base.workflows import WorkflowRunRecord
sesh = Session()
di = WorkflowRunRecord(
model_id='test_model',
input_filepath='/test/input/directory',
......@@ -66,18 +81,17 @@ class TestGetSessionObject(unittest.TestCase):
success=True,
timer_results={'start': 0.123},
)
sesh.record_workflow_run(di)
with open(sesh.manifest_json, 'r') as fh:
self.sesh.record_workflow_run(di)
with open(self.sesh.manifest_json, 'r') as fh:
do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_loads_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel
success = sesh.load_model(MC)
success = self.sesh.load_model(MC)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
loaded_models = self.sesh.describe_loaded_models()
self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys()
)
......@@ -87,46 +101,42 @@ class TestGetSessionObject(unittest.TestCase):
)
def test_session_loads_second_instance_of_same_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel
sesh.load_model(MC)
sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', sesh.models.keys())
self.assertIn(MC.__name__ + '_01', sesh.models.keys())
self.sesh.load_model(MC)
self.sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummySemanticSegmentationModel
p1 = {'p1': 'abc'}
success = sesh.load_model(MC, params=p1)
success = self.sesh.load_model(MC, params=p1)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
loaded_models = self.sesh.describe_loaded_models()
mid = MC.__name__ + '_00'
self.assertEqual(loaded_models[mid]['params'], p1)
# load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'}
sesh.load_model(MC, params=p2)
find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
self.sesh.load_model(MC, params=p2)
find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(mid, find_mid)
self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self):
sesh = Session()
MC = DummySemanticSegmentationModel
p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'}
mid = sesh.load_model(MC, params=p1)
mid = self.sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
self.assertEqual(mid, find_mid)
def test_change_output_path(self):
import pathlib
sesh = Session()
pa = sesh.get_paths()['inbound_images']
pa = self.sesh.get_paths()['inbound_images']
self.assertIsInstance(pa, pathlib.Path)
sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images'])
self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file
self.sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment