# test_session.py
# Unit tests for model_server.session.Session.
import unittest
from model_server.models import DummyImageToImageModel
from model_server.session import Session

class TestGetSessionObject(unittest.TestCase):
    """Tests for the Session object: identity, on-disk files, restart, and model loading.

    ``Session()`` hands back one shared object (asserted in
    ``test_single_session_instance``), so state created by one test may be
    visible to the others.

    NOTE(review): the assertions on the ``_00`` / ``_01`` model-id suffixes
    assume that no model was loaded by an earlier test. Confirm that
    ``Session()`` or ``restart()`` clears loaded models; otherwise add a
    ``tearDown`` that resets the session.
    """

    def setUp(self) -> None:
        """No per-test fixtures; every test obtains the Session itself."""
        pass

    def test_single_session_instance(self):
        """Re-instantiating Session yields the same object, whose log and manifest files exist."""
        from os.path import exists

        sesh = Session()
        self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')

        self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place')
        self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')

    def test_changing_session_root_creates_new_directory(self):
        """Restarting with a new root puts every session path under that root."""
        from conf.defaults import root
        from shutil import rmtree

        sesh = Session()
        old_paths = sesh.get_paths()
        newroot = root / 'subdir'
        try:
            sesh.restart(root=newroot)
            new_paths = sesh.get_paths()
            for k in old_paths:
                self.assertTrue(
                    str(new_paths[k]).startswith(str(newroot)),
                    f'Session path {k!r} is not located under the new root',
                )
        finally:
            # Clean up even when restart() or an assertion above fails, so a
            # broken run does not leave the temporary directory behind.
            # ignore_errors keeps a missing directory from masking the real failure.
            rmtree(newroot, ignore_errors=True)
        # NOTE(review): the shared Session still refers to newroot after it is
        # deleted; presumably later tests recreate their paths -- confirm.
        self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')

    def test_restart_session(self):
        """Restarting the session switches to a different log file."""
        sesh = Session()
        logfile1 = sesh.session_log
        sesh.restart()
        logfile2 = sesh.session_log
        self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')

    def test_session_records_workflow(self):
        """A recorded workflow run is written to the manifest JSON unchanged."""
        import json
        from model_server.workflows import WorkflowRunRecord

        sesh = Session()
        record_in = WorkflowRunRecord(
            model_id='test_model',
            input_filepath='/test/input/directory',
            output_filepath='/test/output/fi.le',
            success=True,
            timer_results={'start': 0.123},
        )
        sesh.record_workflow_run(record_in)
        with open(sesh.manifest_json, 'r') as fh:
            record_out = json.load(fh)
        self.assertEqual(record_in.dict(), record_out, 'Manifest record is not correct')

    def test_session_loads_model(self):
        """Loading a model class registers it under '<ClassName>_00' with its class name."""
        sesh = Session()
        MC = DummyImageToImageModel
        success = sesh.load_model(MC)
        self.assertTrue(success)
        loaded_models = sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        # assertIn reports the missing key and the container on failure.
        self.assertIn(mid, loaded_models)
        self.assertEqual(loaded_models[mid]['class'], MC.__name__)

    def test_session_loads_second_instance_of_same_model(self):
        """Loading the same model class twice yields the '_00' and '_01' ids."""
        sesh = Session()
        MC = DummyImageToImageModel
        sesh.load_model(MC)
        sesh.load_model(MC)
        self.assertIn(MC.__name__ + '_00', sesh.models.keys())
        self.assertIn(MC.__name__ + '_01', sesh.models.keys())

    def test_session_loads_model_with_params(self):
        """Params passed at load time are reported back and can be used to find the model."""
        sesh = Session()
        MC = DummyImageToImageModel
        p1 = {'p1': 'abc'}
        success = sesh.load_model(MC, params=p1)
        self.assertTrue(success)
        loaded_models = sesh.describe_loaded_models()
        mid = MC.__name__ + '_00'
        self.assertEqual(loaded_models[mid]['params'], p1)

        # load a second model and confirm that the first is locatable by its param entry
        p2 = {'p2': 'def'}
        sesh.load_model(MC, params=p2)
        find_kv = sesh.find_param_in_loaded_models('p1', 'abc')
        self.assertEqual(len(find_kv), 1)
        self.assertEqual(find_kv[mid]['params'], p1)