Skip to content
Snippets Groups Projects
test_session.py 5.7 KiB
Newer Older
from os.path import exists
import unittest
from base.models import DummySemanticSegmentationModel
from base.session import session

class TestGetSessionObject(unittest.TestCase):
    def setUp(self) -> None:
        """Give every test a freshly restarted handle on the shared ``session`` object."""
        # NOTE(review): the original body is missing from this copy of the file.
        # Restarting is inferred from test_get_logs, which expects the log to contain
        # only the 'Initialized session' record before it writes three more entries,
        # and from test_changing_session_root..., which moves the root and needs it
        # reset for later tests. Confirm restart() accepts no arguments.
        session.restart()
        self.sesh = session
    def test_session_logfile_is_valid(self):
        """The session's ``logfile`` attribute must point at a file that exists on disk."""
        log_path = self.sesh.logfile
        log_present = exists(log_path)
        self.assertTrue(log_present, 'Session did not create a log file in the correct place')

    def test_changing_session_root_creates_new_directory(self):
        """Restarting the session under a new root must relocate every session path beneath it."""
        from model_server.conf.defaults import root
        # Snapshot the path keys before the restart so each one can be checked afterwards.
        old_paths = self.sesh.get_paths()
        # NOTE(review): `newroot` was referenced but never defined in this copy of the file.
        # A subdirectory of the default root is used here; assumes `root` is a
        # pathlib.Path (or supports `/`) — confirm against model_server.conf.defaults.
        newroot = root / 'subdir'
        self.sesh.restart(root=newroot)
        new_paths = self.sesh.get_paths()
        for k in old_paths:
            self.assertTrue(str(new_paths[k]).startswith(str(newroot)))
    def test_change_session_subdirectory(self):
        """Pointing 'outbound_images' at the 'inbound_images' directory makes the two paths equal."""
        # `old_paths` was used without being defined; fetch it the same way the
        # root-change test does. The stray debug print() is dropped.
        old_paths = self.sesh.get_paths()
        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
        self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
    def test_restarting_session_creates_new_logfile(self):
        """Restarting the session must start a new, existing log file."""
        # The original asserted on logfile1/logfile2 before assigning them and never
        # restarted; the restart step is reconstructed from the assertion message.
        logfile1 = self.sesh.logfile
        self.assertTrue(logfile1.exists())
        self.sesh.restart()
        logfile2 = self.sesh.logfile
        self.assertTrue(logfile2.exists())
        self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')

    def test_reimporting_session_does_not_create_new_logfile(self):
        """Importing ``session`` again, via either package path, yields the same object and log file."""
        # NOTE(review): this body was fused into the previous test and had lost its
        # own `def` line; the method name is reconstructed.
        session1 = self.sesh
        logfile1 = session1.logfile
        from base.session import session as session2
        self.assertEqual(session1, session2)
        logfile2 = session2.logfile
        self.assertTrue(logfile2.exists())
        self.assertEqual(logfile1, logfile2, 'Reimporting session incorrectly creates new logfile')

        from model_server.base.session import session as session3
        self.assertEqual(session1, session3)
        logfile3 = session3.logfile
        self.assertTrue(logfile3.exists())
        self.assertEqual(logfile1, logfile3, 'Reimporting session incorrectly creates new logfile')

    def test_session_logs_message(self):
        """A message passed to ``log_info`` must appear in the session's log file."""
        # NOTE(review): `msg` and the body of the `with` block were missing; the
        # message text and the containment check are reconstructed.
        msg = 'test message for session log'
        self.sesh.log_info(msg)
        with open(self.sesh.logfile, 'r') as fh:
            log = fh.read()
        self.assertIn(msg, log)

    def test_get_logs(self):
        """Write three log records, then check the count and two specific records from get_log_data()."""
        info, warn = self.sesh.log_info, self.sesh.log_warning
        for emit, text in (
            (info, 'Info example 1'),
            (warn, 'Example warning'),
            (info, 'Info example 2'),
        ):
            emit(text)
        records = self.sesh.get_log_data()
        # Three records written here plus the session's own initialization record.
        self.assertEqual(len(records), 4)
        self.assertEqual(records[1]['level'], 'WARNING')
        self.assertEqual(records[-1]['message'], 'Initialized session')

    def test_session_loads_model(self):
        """Loading a model class registers it as '<ClassName>_00' and records its class name."""
        MC = DummySemanticSegmentationModel
        # The load step was missing; without it describe_loaded_models() has nothing to report.
        self.sesh.load_model(MC)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertIn(mid, loaded_models)
        # The `self.assertEqual(` opener was lost, leaving a SyntaxError; restored here.
        self.assertEqual(
            loaded_models[mid]['class'],
            MC.__name__
        )

    def test_session_loads_second_instance_of_same_model(self):
        """Loading the same model class twice registers two instances, suffixed _00 and _01."""
        MC = DummySemanticSegmentationModel
        self.sesh.load_model(MC)
        self.sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
        self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())

    def test_session_loads_model_with_params(self):
        """A model loaded with params reports them and can be found again by a param value."""
        # NOTE(review): this body was fused into the previous test without its own
        # `def` line; the method name, the load_model(..., params=...) calls and the
        # final find_mid assertion are reconstructed.
        # NOTE(review): BaseModel is not imported anywhere in the visible file;
        # presumably pydantic.BaseModel — confirm the module-level import exists.
        MC = DummySemanticSegmentationModel

        class _PM(BaseModel):
            p: str

        p1 = _PM(p='abc')
        self.sesh.load_model(MC, params=p1)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertEqual(loaded_models[mid]['params'], p1)

        # load a second model and confirm that the first is locatable by its param entry
        self.sesh.load_model(MC, params=_PM(p='def'))
        find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
        # presumably find_param_in_loaded_models returns the model id — confirm
        self.assertEqual(mid, find_mid)
        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)

    def test_session_finds_existing_model_with_different_path_formats(self):
        """A model loaded with a backslash path can be found via the same path written with slashes."""
        MC = DummySemanticSegmentationModel

        # `_PM` was referenced here but only ever defined locally in another test.
        # NOTE(review): BaseModel is not imported in the visible file; presumably
        # pydantic.BaseModel — confirm.
        class _PM(BaseModel):
            path: str

        mod_pw = _PM(path='c:\\windows\\dummy.pa')
        mod_pu = _PM(path='c:/windows/dummy.pa')

        mid = self.sesh.load_model(MC, params=mod_pw)
        find_mid = self.sesh.find_param_in_loaded_models('path', mod_pu.path, is_path=True)
        # find_mid was computed but never checked; presumably it should match the loaded id.
        self.assertEqual(mid, find_mid)

    def test_change_output_path_to_string(self):
        """set_data_directory accepts a str and stores it as a pathlib.Path."""
        # NOTE(review): this body was fused into the previous test; the method name
        # and the definition of `pa` (the current inbound path) are reconstructed.
        import pathlib
        pa = self.sesh.paths['inbound_images']
        self.sesh.set_data_directory('outbound_images', pa.__str__())
        self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)

    def test_make_table(self):
        """Write two DataFrame chunks to the same session table and read the CSV back.

        Each ``write_to_table`` call passes the table name, a dict of extra values
        (``{'X': ..., 'Y': ...}``) and a chunk of rows. The read-back checks that
        all eight rows are present and that 'X' and 'Y' are the first two columns,
        presumably because the dict entries are prepended as columns — confirm
        against Session.write_to_table.
        """
        import pandas as pd
        # Eight rows: 'modulo' alternates 0/1 and 'times one hundred' is i * 100.
        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
        # First write (rows 0-3); the table's path should exist afterwards.
        self.sesh.write_to_table(
            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
        )
        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
        # Second write (rows 4-7) into the same table name.
        self.sesh.write_to_table(
            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
        )

        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
        # Rows from both writes are present in the file.
        self.assertEqual(len(dfv), len(data))
        self.assertEqual(dfv.columns[0], 'X')
        self.assertEqual(dfv.columns[1], 'Y')