"""Unit tests for ``model_server.base.session.Session``."""
import pathlib
import unittest
from os.path import exists
from shutil import rmtree

from pydantic import BaseModel

from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.session import Session
class TestGetSessionObject(unittest.TestCase):
    """Tests for the ``Session`` object: singleton behavior, paths, logging, models, and tables."""

    def setUp(self) -> None:
        self.sesh = Session()

    def tearDown(self) -> None:
        print('Tearing down...')
        # Clear the singleton registry so the next test constructs a fresh Session
        Session._instances = {}

    def test_session_is_singleton(self):
        """Repeated ``Session()`` calls return one shared, registered instance."""
        Session._instances = {}
        self.assertEqual(len(Session._instances), 0)
        s = Session()
        self.assertEqual(len(Session._instances), 1)
        self.assertIs(s, Session())
        self.assertEqual(len(Session._instances), 1)

    def test_session_creates_logfile(self):
        """Constructing a Session creates its log file on disk."""
        self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')

    def test_changing_session_root_creates_new_directory(self):
        """Restarting with a new root relocates every session path under that root."""
        from model_server.conf.defaults import root

        old_paths = self.sesh.get_paths()
        newroot = root / 'subdir'
        # Best-effort removal of the temporary root even if an assertion below fails;
        # a no-op when the explicit rmtree at the end of the test already succeeded.
        self.addCleanup(rmtree, newroot, ignore_errors=True)

        self.sesh.restart(root=newroot)
        new_paths = self.sesh.get_paths()
        for k in old_paths:
            with self.subTest(path_key=k):
                self.assertTrue(str(new_paths[k]).startswith(str(newroot)))

        # this is necessary because logger itself is a singleton class
        self.tearDown()
        self.setUp()
        rmtree(newroot)
        self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')

    def test_change_session_subdirectory(self):
        """A data directory can be pointed at another directory's path."""
        old_paths = self.sesh.get_paths()
        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
        self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])

    def test_restarting_session_creates_new_logfile(self):
        """``restart()`` switches the session to a different log file."""
        logfile1 = self.sesh.logfile
        self.sesh.restart()
        logfile2 = self.sesh.logfile
        self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')

    def test_log_warning(self):
        """A message passed to ``log_warning`` appears in the session log file."""
        msg = 'A test warning'
        self.sesh.log_warning(msg)
        with open(self.sesh.logfile, 'r') as fh:
            log = fh.read()
        self.assertIn(msg, log)

    def test_get_logs(self):
        """``get_log_data`` returns every entry, including the session's own init message."""
        self.sesh.log_info('Info example 1')
        self.sesh.log_warning('Example warning')
        self.sesh.log_info('Info example 2')
        logs = self.sesh.get_log_data()
        # Three explicit entries plus 'Initialized session'; the indices below expect newest-first order
        self.assertEqual(len(logs), 4)
        self.assertEqual(logs[1]['level'], 'WARNING')
        self.assertEqual(logs[-1]['message'], 'Initialized session')

    def test_session_loads_model(self):
        """A loaded model is registered under ``<ClassName>_00`` with its class recorded."""
        MC = DummySemanticSegmentationModel
        success = self.sesh.load_model(MC)
        self.assertTrue(success)

        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertIn(mid, loaded_models)
        self.assertEqual(loaded_models[mid]['class'], MC.__name__)

    def test_session_loads_second_instance_of_same_model(self):
        """Loading the same model class twice yields two distinct suffixed ids."""
        MC = DummySemanticSegmentationModel
        self.sesh.load_model(MC)
        self.sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', self.sesh.models)
        self.assertIn(MC.__name__ + '_01', self.sesh.models)

    def test_session_loads_model_with_params(self):
        """Model params are stored and can be used to look the model up again."""
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            p: str

        p1 = _PM(p='abc')
        success = self.sesh.load_model(MC, params=p1)
        self.assertTrue(success)

        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertEqual(loaded_models[mid]['params'], p1)

        # load a second model and confirm that the first is locatable by its param entry
        p2 = _PM(p='def')
        self.sesh.load_model(MC, params=p2)
        find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
        self.assertEqual(mid, find_mid)
        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)

    def test_session_finds_existing_model_with_different_path_formats(self):
        """Path-valued params match across backslash and forward-slash spellings."""
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            path: str

        mod_pw = _PM(path='c:\\windows\\dummy.pa')
        mod_pu = _PM(path='c:/windows/dummy.pa')
        # The test treats load_model's return value as the new model's id
        mid = self.sesh.load_model(MC, params=mod_pw)
        find_mid = self.sesh.find_param_in_loaded_models('path', mod_pu.path, is_path=True)
        self.assertEqual(mid, find_mid)

    def test_change_output_path(self):
        """Setting a data directory from a string still stores a ``pathlib.Path``."""
        pa = self.sesh.get_paths()['inbound_images']
        self.assertIsInstance(pa, pathlib.Path)
        self.sesh.set_data_directory('outbound_images', str(pa))
        self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)

    def test_make_table(self):
        """Successive ``write_to_table`` calls append rows to one CSV, led by the coordinate columns."""
        import pandas as pd
        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
        self.sesh.write_to_table(
            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
        )
        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
        self.sesh.write_to_table(
            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
        )
        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
        self.assertEqual(len(dfv), len(data))
        # The coordinate dict's keys are expected to appear as the leading columns
        self.assertEqual(dfv.columns[0], 'X')
        self.assertEqual(dfv.columns[1], 'Y')