"""Unit tests for the ``Session`` singleton from ``model_server.base.session``."""
import json
import pathlib
import unittest
from os.path import exists

from pydantic import BaseModel

from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.session import Session
from model_server.base.workflows import WorkflowRunRecord


class TestGetSessionObject(unittest.TestCase):
    """Exercise Session lifecycle, logging, model loading, data paths and tables."""

    def setUp(self) -> None:
        # Session() hands back the shared instance (see test_session_is_singleton).
        self.sesh = Session()

    def tearDown(self) -> None:
        print('Tearing down...')
        # Forget the cached singleton so the next test constructs a fresh Session.
        Session._instances = {}

    def test_session_is_singleton(self):
        """Repeated construction yields one cached instance."""
        Session._instances = {}
        self.assertEqual(len(Session._instances), 0)
        s = Session()
        self.assertEqual(len(Session._instances), 1)
        self.assertIs(s, Session())
        self.assertEqual(len(Session._instances), 1)

    def test_session_logfile_is_valid(self):
        """A new session creates its log file on disk."""
        self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')

    def test_changing_session_root_creates_new_directory(self):
        """Restarting under a new root relocates every session path beneath it."""
        from model_server.conf.defaults import root
        from shutil import rmtree

        old_paths = self.sesh.get_paths()
        newroot = root / 'subdir'
        try:
            self.sesh.restart(root=newroot)
            new_paths = self.sesh.get_paths()
            for k in old_paths:
                self.assertTrue(str(new_paths[k]).startswith(str(newroot)))
        finally:
            # Run cleanup even when an assertion above fails, so the temporary
            # directory is not leaked and later tests do not inherit it.
            # this is necessary because logger itself is a singleton class
            self.tearDown()
            self.setUp()
            # ignore_errors: if restart() failed before creating newroot, don't
            # mask the original error; the assertFalse below reports leftovers.
            rmtree(newroot, ignore_errors=True)
        self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')

    def test_change_session_subdirectory(self):
        """Pointing one data directory at another's path is reflected in sesh.paths."""
        old_paths = self.sesh.get_paths()
        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
        self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])

    def test_restarting_session_creates_new_logfile(self):
        """restart() switches the session to a different, existing log file."""
        logfile1 = self.sesh.logfile
        self.assertTrue(logfile1.exists())
        self.sesh.restart()
        logfile2 = self.sesh.logfile
        self.assertTrue(logfile2.exists())
        self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')

    def test_log_warning(self):
        """A message passed to log_warning is written to the session log file."""
        msg = 'A test warning'
        self.sesh.log_warning(msg)
        with open(self.sesh.logfile, 'r') as fh:
            log = fh.read()
        self.assertIn(msg, log)

    def test_get_logs(self):
        """get_log_data returns logged records with their levels and messages."""
        self.sesh.log_info('Info example 1')
        self.sesh.log_warning('Example warning')
        self.sesh.log_info('Info example 2')
        logs = self.sesh.get_log_data()
        # Four entries: the three above plus one with message 'Initialized session'.
        # The index checks below imply newest-first ordering (TODO confirm in Session).
        self.assertEqual(len(logs), 4)
        self.assertEqual(logs[1]['level'], 'WARNING')
        self.assertEqual(logs[-1]['message'], 'Initialized session')

    def test_session_loads_model(self):
        """Loading a model class registers it under '<ClassName>_00'."""
        MC = DummySemanticSegmentationModel
        success = self.sesh.load_model(MC)
        self.assertTrue(success)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertIn(mid, loaded_models)
        self.assertEqual(loaded_models[mid]['class'], MC.__name__)

    def test_session_loads_second_instance_of_same_model(self):
        """A second load of the same class gets the next numeric suffix."""
        MC = DummySemanticSegmentationModel
        self.sesh.load_model(MC)
        self.sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', self.sesh.models)
        self.assertIn(MC.__name__ + '_01', self.sesh.models)

    def test_session_loads_model_with_params(self):
        """Model params are stored and can be used to look the model up."""
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            p: str

        p1 = _PM(p='abc')
        success = self.sesh.load_model(MC, params=p1)
        self.assertTrue(success)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertEqual(loaded_models[mid]['params'], p1)

        # load a second model and confirm that the first is locatable by its param entry
        p2 = _PM(p='def')
        self.sesh.load_model(MC, params=p2)
        find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
        self.assertEqual(mid, find_mid)
        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)

    def test_session_finds_existing_model_with_different_path_formats(self):
        """Path params match regardless of separator style when is_path=True."""
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            path: str

        p1 = _PM(path='c:\\windows\\dummy.pa')
        p2 = _PM(path='c:/windows/dummy.pa')
        mid = self.sesh.load_model(MC, params=p1)
        # Precondition: both spellings normalize to the same pathlib.Path here.
        self.assertEqual(pathlib.Path(p1.path), pathlib.Path(p2.path))
        find_mid = self.sesh.find_param_in_loaded_models('path', p2.path, is_path=True)
        self.assertEqual(mid, find_mid)

    def test_change_output_path(self):
        """set_data_directory accepts a str and stores a pathlib.Path."""
        pa = self.sesh.get_paths()['inbound_images']
        self.assertIsInstance(pa, pathlib.Path)
        self.sesh.set_data_directory('outbound_images', str(pa))
        self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)

    def test_make_table(self):
        """write_to_table appends rows to a CSV, prefixing the given coordinate columns."""
        import pandas as pd

        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(8)]
        self.sesh.write_to_table(
            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
        )
        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
        self.sesh.write_to_table(
            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
        )

        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
        self.assertEqual(len(dfv), len(data))
        self.assertEqual(dfv.columns[0], 'X')
        self.assertEqual(dfv.columns[1], 'Y')