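# NOTE: the following lines are web-page residue captured along with this file; kept as comments:
# Skip to content
# Snippets Groups Projects
# test_session.py 5.72 KiB
# Newer Older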
import json
from os.path import exists
import unittest
from model_server.base.models import DummySemanticSegmentationModel
from model_server.base.session import Session
from model_server.base.workflows import WorkflowRunRecord

class TestGetSessionObject(unittest.TestCase):
    def setUp(self) -> None:
        """Create the ``Session`` under test and expose it as ``self.sesh``.

        Every test in this class reads ``self.sesh``. Because ``tearDown`` empties
        ``Session._instances``, calling ``Session()`` here registers a new instance
        instead of returning the previous test's one.
        """
        self.sesh = Session()
    def tearDown(self) -> None:
        """Forget every registered ``Session`` so the next test gets a new one."""
        print('Tearing down...')
        Session._instances = dict()

    def test_session_is_singleton(self):
        """Repeated ``Session()`` calls must return one shared, singly-registered instance."""
        # start from an empty registry regardless of what setUp created
        Session._instances = {}
        self.assertEqual(len(Session._instances), 0)
        first = Session()
        self.assertEqual(len(Session._instances), 1)
        second = Session()
        self.assertIs(first, second)
        # the second call must not have registered a new instance
        self.assertEqual(len(Session._instances), 1)
    def test_session_logfile_is_valid(self):
        """The session's ``logfile`` path must exist on disk."""
        logfile = self.sesh.logfile
        self.assertTrue(
            exists(logfile),
            'Session did not create a log file in the correct place',
        )

    def test_changing_session_root_creates_new_directory(self):
        """Checks that after ``restart(root=newroot)`` every session path lies under ``newroot``.

        NOTE(review): ``old_paths`` and ``newroot`` are used below but never assigned
        in this view, and the imported ``root`` is otherwise unused. Presumably lines
        like ``old_paths = self.sesh.get_paths()`` and a ``newroot`` derived from
        ``root`` are missing — TODO restore from the repository.
        NOTE(review): the ``rmtree`` cleanup runs only if every assertion passes;
        consider ``self.addCleanup`` so a failure does not leave the directory behind.
        """
        from model_server.conf.defaults import root
        from shutil import rmtree

        self.sesh.restart(root=newroot)
        new_paths = self.sesh.get_paths()
        # each path key known before the restart must now point inside newroot
        for k in old_paths.keys():
            self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))

        # tearDown/setUp rebuild the session; this is necessary because the logger itself is a singleton class
        self.tearDown()
        self.setUp()
        # remove the temporary root and confirm it is gone
        rmtree(newroot)
        self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
    def test_change_session_subdirectory(self):
        """Redirecting 'outbound_images' to the 'inbound_images' directory updates ``paths``.

        ``old_paths`` was referenced without being defined (NameError); it is the
        session's path mapping as returned by ``get_paths()``.
        """
        old_paths = self.sesh.get_paths()
        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
        self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
    def test_restarting_session_creates_new_logfile(self):
        """Checks that two log file paths both exist and differ from each other.

        NOTE(review): this view appears to be missing lines. ``logfile1``,
        ``logfile2`` and ``msg`` are never assigned here, and nothing visibly
        restarts the session between the two existence checks. The final two
        statements look like the start of a separate log-message test whose
        ``def`` line and assertions are absent; the ``with`` statement has no
        body, which is a SyntaxError — TODO restore from the repository.
        """
        self.assertTrue(logfile1.exists())
        self.assertTrue(logfile2.exists())
        self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
        # NOTE(review): likely belongs to a different test (write a message, then read the log file)
        self.sesh.log_info(msg)
        with open(self.sesh.logfile, 'r') as fh:
    def test_get_logs(self):
        """Expects the three messages written here plus a trailing 'Initialized session' entry."""
        sesh = self.sesh
        for emit, text in (
            (sesh.log_info, 'Info example 1'),
            (sesh.log_warning, 'Example warning'),
            (sesh.log_info, 'Info example 2'),
        ):
            emit(text)
        entries = sesh.get_log_data()
        self.assertEqual(len(entries), 4)
        # per these assertions the warning sits at index 1 and the init entry comes last
        self.assertEqual(entries[1]['level'], 'WARNING')
        self.assertEqual(entries[-1]['message'], 'Initialized session')

    def test_session_loads_model(self):
        """Loading a model class registers it as ``<ClassName>_00`` with that class name.

        Fixes a SyntaxError (the ``assertEqual`` call had lost its opening line) and
        loads the model before inspecting the registry.
        """
        MC = DummySemanticSegmentationModel
        self.sesh.load_model(MC)
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertIn(mid, loaded_models.keys())
        self.assertEqual(
            loaded_models[mid]['class'],
            MC.__name__
        )

    def test_session_loads_second_instance_of_same_model(self):
        """Checks that loading one model class twice registers both ``_00`` and ``_01`` ids.

        NOTE(review): from the second ``MC = ...`` onward this view appears to merge
        a separate model-parameters test whose ``def`` line is missing. ``BaseModel``
        is not imported in the visible header, ``p1`` is never passed to
        ``load_model``, the "load a second model" step is absent, and ``find_mid``
        is never asserted on — TODO restore from the repository.
        """
        MC = DummySemanticSegmentationModel
        self.sesh.load_model(MC)
        self.sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
        self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
        # NOTE(review): likely the start of a different test (model params lookup)
        MC = DummySemanticSegmentationModel
        class _PM(BaseModel):
            p: str
        p1 = _PM(p='abc')
        loaded_models = self.sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        # the first model's recorded params are expected to equal p1
        self.assertEqual(loaded_models[mid]['params'], p1)

        # load a second model and confirm that the first is locatable by its param entry
        find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)

    def test_session_finds_existing_model_with_different_path_formats(self):
        """Checks that a path-valued param is found whether spelled with '\\' or '/'.

        NOTE(review): this view appears to be missing lines. ``BaseModel`` and
        ``pathlib`` are not imported in the visible header, no model is loaded with
        ``p1`` before the lookup, ``find_mid`` is never asserted on, and ``pa`` is
        undefined. The trailing ``set_data_directory`` checks look like part of a
        separate test — TODO restore from the repository.
        """
        MC = DummySemanticSegmentationModel
        class _PM(BaseModel):
            path: str

        p1 = _PM(path='c:\\windows\\dummy.pa')
        p2 = _PM(path='c:/windows/dummy.pa')
        # NOTE(review): pathlib treats '\\' as a separator only on Windows, so on POSIX
        # these two Paths differ and this bare assert fails; prefer PureWindowsPath
        # and self.assertEqual if the test must run cross-platform
        assert pathlib.Path(p1.path) == pathlib.Path(p2.path)
        find_mid = self.sesh.find_param_in_loaded_models('path', p2.path, is_path=True)
        # NOTE(review): likely belongs to a different test (data directory given as a str)
        self.sesh.set_data_directory('outbound_images', pa.__str__())
        self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
        # a str argument is expected to be stored as a pathlib.Path
        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)

    def test_make_table(self):
        """Checks that two writes to one named table accumulate rows in its CSV file.

        Two writes of four rows each should read back as eight rows, with ``X`` and
        ``Y`` as the first two columns.
        """
        import pandas as pd
        # eight rows with two data columns, split across the two writes below
        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
        self.sesh.write_to_table(
            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
        )
        # the first write must create the table's backing file
        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
        self.sesh.write_to_table(
            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
        )

        # read the file back as CSV; rows from both writes should be present
        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
        self.assertEqual(len(dfv), len(data))
        # NOTE(review): presumably the {'X', 'Y'} dict becomes leading per-row columns;
        # the visible assertions check only the column names, not their values
        self.assertEqual(dfv.columns[0], 'X')
        self.assertEqual(dfv.columns[1], 'Y')