import json
import os
import pathlib
import unittest
from shutil import rmtree

from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.session import Session


class TestGetSessionObject(unittest.TestCase):
    """Tests for ``Session``: singleton behaviour, paths, logging, manifests and model loading.

    ``Session()`` returns the same object on every call (asserted in
    ``test_single_session_instance``), so any state a test changes on it is visible
    to later tests unless that test restores it.
    """

    def test_single_session_instance(self):
        """Constructing ``Session`` twice yields one object, and its log/manifest files exist."""
        sesh = Session()
        self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')

        self.assertTrue(os.path.exists(sesh.logfile), 'Session did not create a log file in the correct place')
        self.assertTrue(
            os.path.exists(sesh.manifest_json),
            'Session did not create a manifest JSON file in the correct place',
        )

    def test_changing_session_root_creates_new_directory(self):
        """Restarting with a new root places every session path under that root."""
        from model_server.conf.defaults import root

        sesh = Session()
        old_paths = sesh.get_paths()
        newroot = root / 'subdir'
        sesh.restart(root=newroot)
        try:
            new_paths = sesh.get_paths()
            for k in old_paths:
                self.assertTrue(
                    str(new_paths[k]).startswith(str(newroot)),
                    f'Session path {k!r} is not under the new root',
                )
        finally:
            # The Session is a shared singleton: point it back at the default root before
            # deleting the temporary one, so later tests are not left using a removed
            # directory. This runs even if an assertion above fails.
            # NOTE(review): assumes the session was on the default root before this test.
            sesh.restart(root=root)
            rmtree(newroot, ignore_errors=True)
        self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')

    def test_change_session_subdirectory(self):
        """Pointing one data directory at another's path makes the two entries equal."""
        sesh = Session()
        old_paths = sesh.get_paths()
        sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
        self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])

    def test_restart_session(self):
        """Restarting the session switches to a different log file."""
        sesh = Session()
        logfile1 = sesh.logfile
        sesh.restart()
        logfile2 = sesh.logfile
        self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')

    def test_log_warning(self):
        """A message passed to ``log_info`` appears in the session log file."""
        # NOTE(review): this was reported as passing alone but failing alongside other tests;
        # one candidate cause is a previous test leaving the shared Session rooted in a
        # deleted directory. Also, despite the name, this exercises ``log_info``. Confirm.
        sesh = Session()
        msg = 'A test warning'
        sesh.log_info(msg)

        with open(sesh.logfile, 'r') as fh:
            log = fh.read()
        self.assertIn(msg, log)

    def test_session_records_workflow(self):
        """A recorded workflow run is written to the manifest JSON verbatim."""
        from model_server.base.workflows import WorkflowRunRecord

        sesh = Session()
        di = WorkflowRunRecord(
            model_id='test_model',
            input_filepath='/test/input/directory',
            output_filepath='/test/output/fi.le',
            success=True,
            timer_results={'start': 0.123},
        )
        sesh.record_workflow_run(di)
        with open(sesh.manifest_json, 'r') as fh:
            do = json.load(fh)
        self.assertEqual(di.dict(), do, 'Manifest record is not correct')

    def test_session_loads_model(self):
        """Loading a model registers it under ``<ClassName>_00`` with its class name."""
        sesh = Session()
        MC = DummySemanticSegmentationModel
        success = sesh.load_model(MC)
        self.assertTrue(success)

        loaded_models = sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertIn(mid, loaded_models)
        self.assertEqual(loaded_models[mid]['class'], MC.__name__)

    def test_session_loads_second_instance_of_same_model(self):
        """Loading the same class twice yields two distinctly-suffixed model ids."""
        sesh = Session()
        MC = DummySemanticSegmentationModel
        sesh.load_model(MC)
        sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', sesh.models.keys())
        self.assertIn(MC.__name__ + '_01', sesh.models.keys())

    def test_session_loads_model_with_params(self):
        """Model params are recorded and a model can be found again by a param value."""
        sesh = Session()
        MC = DummySemanticSegmentationModel
        p1 = {'p1': 'abc'}
        success = sesh.load_model(MC, params=p1)
        self.assertTrue(success)

        loaded_models = sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertEqual(loaded_models[mid]['params'], p1)

        # Load a second model and confirm the first is still locatable by its param entry.
        p2 = {'p2': 'def'}
        sesh.load_model(MC, params=p2)
        find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
        self.assertEqual(mid, find_mid)
        self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)

    def test_session_finds_existing_model_with_different_path_formats(self):
        """Path-valued params match regardless of separator style when ``is_path=True``."""
        sesh = Session()
        MC = DummySemanticSegmentationModel
        p1 = {'path': 'c:\\windows\\dummy.pa'}
        p2 = {'path': 'c:/windows/dummy.pa'}
        # NOTE(review): this treats load_model's return value as the model id, whereas
        # test_session_loads_model treats it as a success flag. Confirm which it returns.
        mid = sesh.load_model(MC, params=p1)
        # Sanity check on the fixture itself: both spellings must denote the same path.
        self.assertEqual(pathlib.Path(p1['path']), pathlib.Path(p2['path']))
        find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
        self.assertEqual(mid, find_mid)

    def test_change_output_path(self):
        """Setting a data directory from a string stores it as a ``pathlib.Path``."""
        sesh = Session()
        pa = sesh.get_paths()['inbound_images']
        self.assertIsInstance(pa, pathlib.Path)
        sesh.set_data_directory('outbound_images', str(pa))
        self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images'])
        self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path)