-
Christopher Randolph Rhodes authored
test_session.py 5.33 KiB
from os.path import exists
import pathlib
from pydantic import BaseModel
import unittest
import base.session
from base.models import DummySemanticSegmentationModel
from base.session import session
class TestGetSessionObject(unittest.TestCase):
    """Exercise the module-level ``session`` object imported from ``base.session``.

    Every test starts from a freshly restarted session (see ``setUp``) and
    covers log files, data directories, model loading and table output.
    """

    def setUp(self) -> None:
        """Restart the shared session and keep a handle to it as ``self.sesh``."""
        session.restart()
        self.sesh = session

    def test_session_logfile_is_valid(self) -> None:
        """The restarted session points at a log file that exists on disk."""
        self.assertTrue(
            exists(self.sesh.logfile),
            'Session did not create a log file in the correct place',
        )

    def test_changing_session_root_creates_new_directory(self) -> None:
        """Restarting with a new root relocates every session path under it."""
        from model_server.conf.defaults import root

        old_paths = self.sesh.get_paths()
        newroot = root / 'subdir'
        self.sesh.restart(root=newroot)
        new_paths = self.sesh.get_paths()
        for k in old_paths:
            # subTest reports each offending key instead of stopping at the first
            with self.subTest(path_key=k):
                self.assertTrue(str(new_paths[k]).startswith(str(newroot)))

    def test_change_session_subdirectory(self) -> None:
        """Pointing 'outbound_images' at the inbound directory makes both paths equal."""
        old_paths = self.sesh.get_paths()
        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
        self.assertEqual(
            self.sesh.paths['outbound_images'],
            self.sesh.paths['inbound_images'],
        )

    def test_restarting_session_creates_new_logfile(self) -> None:
        """A restart switches the session to a different, existing log file."""
        logfile1 = self.sesh.logfile
        self.assertTrue(logfile1.exists())
        self.sesh.restart()
        logfile2 = self.sesh.logfile
        self.assertTrue(logfile2.exists())
        self.assertNotEqual(
            logfile1, logfile2, 'Restarting session does not generate new logfile'
        )

    def test_reimporting_session_uses_same_logfile(self) -> None:
        """Accessing the session again through its module yields the same log file."""
        logfile1 = self.sesh.logfile
        self.assertTrue(logfile1.exists())
        session2 = base.session.session
        logfile2 = session2.logfile
        self.assertTrue(logfile2.exists())
        self.assertEqual(
            logfile1, logfile2, 'Reimporting session incorrectly creates new logfile'
        )

    def test_log_warning(self) -> None:
        """A message logged as a warning shows up in the session log file."""
        msg = 'A test warning'
        # Log through log_warning so the test covers the path its name promises
        self.sesh.log_warning(msg)
        with open(self.sesh.logfile, 'r') as fh:
            log = fh.read()
        self.assertIn(msg, log)

    def test_get_logs(self) -> None:
        """``get_log_data`` returns the three new entries plus the initialization entry.

        The index checks below only hold if entries are ordered newest first:
        the warning sits at index 1 and the initialization message comes last.
        """
        self.sesh.log_info('Info example 1')
        self.sesh.log_warning('Example warning')
        self.sesh.log_info('Info example 2')
        logs = self.sesh.get_log_data()
        self.assertEqual(len(logs), 4)
        self.assertEqual(logs[1]['level'], 'WARNING')
        self.assertEqual(logs[-1]['message'], 'Initialized session')

    def test_session_loads_model(self) -> None:
        """Loading a model registers it under '<ClassName>_00' with its class name."""
        MC = DummySemanticSegmentationModel
        success = self.sesh.load_model(MC)
        self.assertTrue(success)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertIn(mid, loaded_models)
        self.assertEqual(loaded_models[mid]['class'], MC.__name__)

    def test_session_loads_second_instance_of_same_model(self) -> None:
        """Loading the same model class twice yields the '_00' and '_01' entries."""
        MC = DummySemanticSegmentationModel
        self.sesh.load_model(MC)
        self.sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
        self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())

    def test_session_loads_model_with_params(self) -> None:
        """Params passed at load time are stored and can be used to find the model."""
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            p: str

        p1 = _PM(p='abc')
        success = self.sesh.load_model(MC, params=p1)
        self.assertTrue(success)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertEqual(loaded_models[mid]['params'], p1)

        # load a second model and confirm that the first is locatable by its param entry
        p2 = _PM(p='def')
        self.sesh.load_model(MC, params=p2)
        find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
        self.assertEqual(mid, find_mid)
        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)

    def test_session_finds_existing_model_with_different_path_formats(self) -> None:
        """A model loaded with a backslash path is found via the forward-slash form."""
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            path: str

        mod_pw = _PM(path='c:\\windows\\dummy.pa')
        mod_pu = _PM(path='c:/windows/dummy.pa')
        # The value returned by load_model is compared against the lookup result below
        mid = self.sesh.load_model(MC, params=mod_pw)
        find_mid = self.sesh.find_param_in_loaded_models(
            'path', mod_pu.path, is_path=True
        )
        self.assertEqual(mid, find_mid)

    def test_change_output_path(self) -> None:
        """Setting a data directory from a string still stores a ``pathlib.Path``."""
        pa = self.sesh.get_paths()['inbound_images']
        self.assertIsInstance(pa, pathlib.Path)
        self.sesh.set_data_directory('outbound_images', str(pa))
        self.assertEqual(
            self.sesh.paths['inbound_images'],
            self.sesh.paths['outbound_images'],
        )
        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)

    def test_make_table(self) -> None:
        """Two writes to one table produce a CSV with all rows, led by columns X and Y."""
        import pandas as pd

        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
        self.sesh.write_to_table(
            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
        )
        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
        self.sesh.write_to_table(
            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
        )
        # Read the file back: it must hold the rows of both writes
        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
        self.assertEqual(len(dfv), len(data))
        self.assertEqual(dfv.columns[0], 'X')
        self.assertEqual(dfv.columns[1], 'Y')