Skip to content
Snippets Groups Projects
Commit 43f84ad3 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Cleared up issues with Singleton behavior, largely confused by the fact this...

Cleared up issues with Singleton behavior, largely confused by the fact this python logger itself is a Singleton that persists between tests
parent 0f1003c4
No related branches found
No related tags found
2 merge requests!16Completed (de)serialization of RoiSet,!9Session exposes a python log
...@@ -13,18 +13,19 @@ from model_server.base.workflows import WorkflowRunRecord ...@@ -13,18 +13,19 @@ from model_server.base.workflows import WorkflowRunRecord
def create_manifest_json(): def create_manifest_json():
pass pass
logger = logging.getLogger(__name__) class Singleton(type):
_instances = {}
class Session(object): def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
class Session(object, metaclass=Singleton):
""" """
Singleton class for a server session that persists data between API calls Singleton class for a server session that persists data between API calls
""" """
def __new__(cls):
if not hasattr(cls, 'instance'):
cls.instance = super(Session, cls).__new__(cls)
return cls.instance
def __init__(self, root: str = None): def __init__(self, root: str = None):
print('Initializing session') print('Initializing session')
self.models = {} # model_id : model object self.models = {} # model_id : model object
...@@ -32,7 +33,8 @@ class Session(object): ...@@ -32,7 +33,8 @@ class Session(object):
self.paths = self.make_paths(root) self.paths = self.make_paths(root)
self.logfile = self.paths['logs'] / f'session.log' self.logfile = self.paths['logs'] / f'session.log'
logging.basicConfig(filename=self.logfile, level=logging.INFO) logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True)
self.logger = logging.getLogger(__name__)
self.log_info('Initialized session') self.log_info('Initialized session')
...@@ -83,7 +85,7 @@ class Session(object): ...@@ -83,7 +85,7 @@ class Session(object):
return f'{yyyymmdd}-{idx:04d}' return f'{yyyymmdd}-{idx:04d}'
def log_info(self, msg): def log_info(self, msg):
logger.info(msg) self.logger.info(msg)
# def log_event(self, event: str): # def log_event(self, event: str):
# """ # """
......
...@@ -5,52 +5,68 @@ from model_server.base.session import Session ...@@ -5,52 +5,68 @@ from model_server.base.session import Session
class TestGetSessionObject(unittest.TestCase): class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
pass self.sesh = Session()
def test_singleton(self):
Session._instances = {}
self.assertEqual(len(Session._instances), 0)
s = Session()
self.assertEqual(len(Session._instances), 1)
print(Session._instances)
self.assertTrue(s.logfile.exists(), s.logfile)
def test_single_session_instance(self): def test_single_session_instance(self):
sesh = Session() self.assertIs(self.sesh, Session(), 'Re-initializing Session class returned a new object')
self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
from os.path import exists from os.path import exists
self.assertTrue(exists(sesh.logfile), 'Session did not create a log file in the correct place') self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') self.assertTrue(exists(self.sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
def tearDown(self) -> None:
print('Tearing down...')
Session._instances = {}
def test_changing_session_root_creates_new_directory(self): def test_changing_session_root_creates_new_directory(self):
from model_server.conf.defaults import root from model_server.conf.defaults import root
from shutil import rmtree from shutil import rmtree
sesh = Session() old_paths = self.sesh.get_paths()
old_paths = sesh.get_paths()
newroot = root / 'subdir' newroot = root / 'subdir'
sesh.restart(root=newroot) self.sesh.restart(root=newroot)
new_paths = sesh.get_paths() new_paths = self.sesh.get_paths()
for k in old_paths.keys(): for k in old_paths.keys():
self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
# this is necessary because logger itself is a singleton class
self.tearDown()
self.setUp()
rmtree(newroot) rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory') self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self): def test_change_session_subdirectory(self):
sesh = Session() old_paths = self.sesh.get_paths()
old_paths = sesh.get_paths()
print(old_paths) print(old_paths)
sesh.set_data_directory('outbound_images', old_paths['inbound_images']) self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images']) self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
def test_restart_session(self): def test_restart_session(self):
sesh = Session() logfile1 = self.sesh.logfile
logfile1 = sesh.logfile self.sesh.restart()
sesh.restart() logfile2 = self.sesh.logfile
logfile2 = sesh.logfile
self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile') self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_call_session_singleton(self):
logfile1 = self.sesh.logfile
sesh2 = Session()
logfile2 = sesh2.logfile
self.assertEqual(logfile1, logfile2, 'Re-initializing session does not generate new logfile')
def test_log_warning(self): def test_log_warning(self):
# this is passing on its own but failing in conjunction with other tests
sesh = Session()
msg = 'A test warning' msg = 'A test warning'
sesh.log_info(msg) self.sesh.log_info(msg)
with open(sesh.logfile, 'r') as fh: with open(self.sesh.logfile, 'r') as fh:
log = fh.read() log = fh.read()
self.assertTrue(msg in log) self.assertTrue(msg in log)
...@@ -58,7 +74,6 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -58,7 +74,6 @@ class TestGetSessionObject(unittest.TestCase):
def test_session_records_workflow(self): def test_session_records_workflow(self):
import json import json
from model_server.base.workflows import WorkflowRunRecord from model_server.base.workflows import WorkflowRunRecord
sesh = Session()
di = WorkflowRunRecord( di = WorkflowRunRecord(
model_id='test_model', model_id='test_model',
input_filepath='/test/input/directory', input_filepath='/test/input/directory',
...@@ -66,18 +81,17 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -66,18 +81,17 @@ class TestGetSessionObject(unittest.TestCase):
success=True, success=True,
timer_results={'start': 0.123}, timer_results={'start': 0.123},
) )
sesh.record_workflow_run(di) self.sesh.record_workflow_run(di)
with open(sesh.manifest_json, 'r') as fh: with open(self.sesh.manifest_json, 'r') as fh:
do = json.load(fh) do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct') self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_loads_model(self): def test_session_loads_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel MC = DummySemanticSegmentationModel
success = sesh.load_model(MC) success = self.sesh.load_model(MC)
self.assertTrue(success) self.assertTrue(success)
loaded_models = sesh.describe_loaded_models() loaded_models = self.sesh.describe_loaded_models()
self.assertTrue( self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys() (MC.__name__ + '_00') in loaded_models.keys()
) )
...@@ -87,46 +101,42 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -87,46 +101,42 @@ class TestGetSessionObject(unittest.TestCase):
) )
def test_session_loads_second_instance_of_same_model(self): def test_session_loads_second_instance_of_same_model(self):
sesh = Session()
MC = DummySemanticSegmentationModel MC = DummySemanticSegmentationModel
sesh.load_model(MC) self.sesh.load_model(MC)
sesh.load_model(MC) self.sesh.load_model(MC)
self.assertIn(MC.__name__ + '_00', sesh.models.keys()) self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
self.assertIn(MC.__name__ + '_01', sesh.models.keys()) self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
def test_session_loads_model_with_params(self): def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummySemanticSegmentationModel MC = DummySemanticSegmentationModel
p1 = {'p1': 'abc'} p1 = {'p1': 'abc'}
success = sesh.load_model(MC, params=p1) success = self.sesh.load_model(MC, params=p1)
self.assertTrue(success) self.assertTrue(success)
loaded_models = sesh.describe_loaded_models() loaded_models = self.sesh.describe_loaded_models()
mid = MC.__name__ + '_00' mid = MC.__name__ + '_00'
self.assertEqual(loaded_models[mid]['params'], p1) self.assertEqual(loaded_models[mid]['params'], p1)
# load a second model and confirm that the first is locatable by its param entry # load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'} p2 = {'p2': 'def'}
sesh.load_model(MC, params=p2) self.sesh.load_model(MC, params=p2)
find_mid = sesh.find_param_in_loaded_models('p1', 'abc') find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(mid, find_mid) self.assertEqual(mid, find_mid)
self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1) self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self): def test_session_finds_existing_model_with_different_path_formats(self):
sesh = Session()
MC = DummySemanticSegmentationModel MC = DummySemanticSegmentationModel
p1 = {'path': 'c:\\windows\\dummy.pa'} p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'} p2 = {'path': 'c:/windows/dummy.pa'}
mid = sesh.load_model(MC, params=p1) mid = self.sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
self.assertEqual(mid, find_mid) self.assertEqual(mid, find_mid)
def test_change_output_path(self): def test_change_output_path(self):
import pathlib import pathlib
sesh = Session() pa = self.sesh.get_paths()['inbound_images']
pa = sesh.get_paths()['inbound_images']
self.assertIsInstance(pa, pathlib.Path) self.assertIsInstance(pa, pathlib.Path)
sesh.set_data_directory('outbound_images', pa.__str__()) self.sesh.set_data_directory('outbound_images', pa.__str__())
self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images']) self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path) self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file \ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment