From 141e977572425e489daba1e1412b286225874ca6 Mon Sep 17 00:00:00 2001
From: Christopher Randolph Rhodes <christopher.rhodes@embl.de>
Date: Wed, 3 Apr 2024 14:44:07 +0200
Subject: [PATCH] Rebuilding release branch to resolve conflicts where master
 is ahead of staging

---
 model_server/base/api.py      |  17 +++-
 model_server/base/session.py  | 113 +++++++++++++++++++-------
 model_server/conf/defaults.py |   1 +
 model_server/conf/testing.py  |   1 +
 tests/test_api.py             |   9 ++-
 tests/test_session.py         | 145 +++++++++++++++++++---------------
 6 files changed, 188 insertions(+), 98 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 987cc381..13abef8e 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -51,6 +51,7 @@ def change_path(key, path):
             status_code=404,
             detail=e.__str__(),
         )
+    session.log_info(f'Change {key} path to {path}')
     return session.get_paths()
 
 @app.put('/paths/watch_input')
@@ -61,22 +62,30 @@ def watch_input_path(path: str):
 def watch_output_path(path: str):
     return change_path('outbound_images', path)
 
-@app.get('/restart')
+@app.get('/session/restart')
 def restart_session(root: str = None) -> dict:
     session.restart(root=root)
     return session.describe_loaded_models()
 
+@app.get('/session/logs')
+def list_session_log() -> list:
+    return session.get_log_data()
+
 @app.get('/models')
 def list_active_models():
     return session.describe_loaded_models()
 
 @app.put('/models/dummy_semantic/load/')
 def load_dummy_model() -> dict:
-    return {'model_id': session.load_model(DummySemanticSegmentationModel)}
+    mid = session.load_model(DummySemanticSegmentationModel)
+    session.log_info(f'Loaded model {mid}')
+    return {'model_id': mid}
 
 @app.put('/models/dummy_instance/load/')
 def load_dummy_model() -> dict:
-    return {'model_id': session.load_model(DummyInstanceSegmentationModel)}
+    mid = session.load_model(DummyInstanceSegmentationModel)
+    session.log_info(f'Loaded model {mid}')
+    return {'model_id': mid}
 
 @app.put('/workflows/segment')
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
@@ -88,5 +97,5 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
         session.paths['outbound_images'],
         channel=channel,
     )
-    session.record_workflow_run(record)
+    session.log_info(f'Completed segmentation of {input_filename}')
     return record
\ No newline at end of file
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 1b91b12a..3d5b789d 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -1,36 +1,82 @@
-import json
+import logging
 import os
 
 from pathlib import Path
 from time import strftime, localtime
 from typing import Dict
 
+import pandas as pd
+
 import model_server.conf.defaults
 from model_server.base.models import Model
-from model_server.base.workflows import WorkflowRunRecord
 
-def create_manifest_json():
-    pass
+logger = logging.getLogger(__name__)
+
+
+class Singleton(type):
+    _instances = {}
+
+    def __call__(cls, *args, **kwargs):
+        if cls not in cls._instances:
+            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
+        return cls._instances[cls]
+
+class CsvTable(object):
+    def __init__(self, fpath: Path):
+        self.path = fpath
+        self.empty = True
 
-class Session(object):
+    def append(self, coords: dict, data: pd.DataFrame) -> bool:
+        assert isinstance(data, pd.DataFrame)
+        for c in reversed(coords.keys()):
+            data.insert(0, c, coords[c])
+        if self.empty:
+            data.to_csv(self.path, index=False, mode='w', header=True)
+        else:
+            data.to_csv(self.path, index=False, mode='a', header=False)
+        self.empty = False
+        return True
+
+class Session(object, metaclass=Singleton):
     """
     Singleton class for a server session that persists data between API calls
     """
 
-    def __new__(cls):
-        if not hasattr(cls, 'instance'):
-            cls.instance = super(Session, cls).__new__(cls)
-        return cls.instance
+    log_format = '%(asctime)s - %(levelname)s - %(message)s'
 
     def __init__(self, root: str = None):
         print('Initializing session')
         self.models = {} # model_id : model object
-        self.manifest = [] # paths to data as well as other metadata from each inference run
         self.paths = self.make_paths(root)
-        self.session_log = self.paths['logs'] / f'session.log'
-        self.log_event('Initialized session')
-        self.manifest_json = self.paths['logs'] / f'manifest.json'
-        open(self.manifest_json, 'w').close() # instantiate empty json file
+
+        self.logfile = self.paths['logs'] / f'session.log'
+        logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format)
+
+        self.log_info('Initialized session')
+        self.tables = {}
+
+    def write_to_table(self, name: str, coords: dict, data: pd.DataFrame):
+        """
+        Write data to a named data table, initializing if it does not yet exist.
+        :param name: name of the table to persist through session
+        :param coords: dictionary of coordinates to associate with all rows in this method call
+        :param data: DataFrame containing data
+        :return: True if successful
+        """
+        try:
+            if name in self.tables.keys():
+                table = self.tables.get(name)
+            else:
+                table = CsvTable(self.paths['tables'] / (name + '.csv'))
+                self.tables[name] = table
+        except Exception:
+            raise CouldNotCreateTable(f'Unable to create table named {name}')
+
+        try:
+            table.append(coords, data)
+            return True
+        except Exception:
+            raise CouldNotAppendToTable(f'Unable to append data to table named {name}')
 
     def get_paths(self):
         return self.paths
@@ -55,7 +101,7 @@ class Session(object):
             root_path = Path(root)
         sid = Session.create_session_id(root_path)
         paths = {'root': root_path}
-        for pk in ['inbound_images', 'outbound_images', 'logs']:
+        for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
             pa = root_path / sid / model_server.conf.defaults.subdirectories[pk]
             paths[pk] = pa
             try:
@@ -75,22 +121,23 @@ class Session(object):
             idx += 1
         return f'{yyyymmdd}-{idx:04d}'
 
-    def log_event(self, event: str):
-        """
-        Write an event string to this session's log file.
-        """
-        timestamp = strftime('%m/%d/%Y, %H:%M:%S', localtime())
-        with open(self.session_log, 'a') as fh:
-            fh.write(f'{timestamp} -- {event}\n')
+    def get_log_data(self) -> list:
+        log = []
+        with open(self.logfile, 'r') as fh:
+            for line in fh:
+                k = ['datatime', 'level', 'message']
+                v = line.strip().split(' - ')[0:3]
+                log.insert(0, dict(zip(k, v)))
+        return log
 
-    def record_workflow_run(self, record: WorkflowRunRecord or None):
-        """
-        Append a JSON describing inference data to this session's manifest
-        """
-        self.log_event(f'Ran model {record.model_id} on {record.input_filepath} to infer {record.output_filepath}')
-        with open(self.manifest_json, 'w+') as fh:
-            json.dump(record.dict(), fh)
+    def log_info(self, msg):
+        logger.info(msg)
+
+    def log_warning(self, msg):
+        logger.warning(msg)
 
+    def log_error(self, msg):
+        logger.error(msg)
 
     def load_model(self, ModelClass: Model, params: Dict[str, str] = None) ->  dict:
         """
@@ -114,7 +161,7 @@ class Session(object):
             'object': mi,
             'params': params
         }
-        self.log_event(f'Loaded model {key}')
+        self.log_info(f'Loaded model {key}')
         return key
 
     def describe_loaded_models(self) -> dict:
@@ -157,5 +204,11 @@ class CouldNotInstantiateModelError(Error):
 class CouldNotCreateDirectory(Error):
     pass
 
+class CouldNotCreateTable(Error):
+    pass
+
+class CouldNotAppendToTable(Error):
+    pass
+
 class InvalidPathError(Error):
     pass
\ No newline at end of file
diff --git a/model_server/conf/defaults.py b/model_server/conf/defaults.py
index 55b114f2..bdf7cfd0 100644
--- a/model_server/conf/defaults.py
+++ b/model_server/conf/defaults.py
@@ -6,6 +6,7 @@ subdirectories = {
     'logs': 'logs',
     'inbound_images': 'images/inbound',
     'outbound_images': 'images/outbound',
+    'tables': 'tables',
 }
 
 server_conf = {
diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py
index 3d07931e..97a13aaf 100644
--- a/model_server/conf/testing.py
+++ b/model_server/conf/testing.py
@@ -67,6 +67,7 @@ roiset_test_data = {
         'c': 5,
         'z': 7,
         'mask_path': root / 'zmask-test-stack-mask.tif',
+        'mask_path_3d': root / 'zmask-test-stack-mask-3d.tif',
     },
     'pipeline_params': {
         'segmentation_channel': 0,
diff --git a/tests/test_api.py b/tests/test_api.py
index 13c62c36..3ac5c8f1 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -142,7 +142,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp_list_0.status_code, 200)
         rj0 = resp_list_0.json()
         self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')
-        resp_restart = self._get('restart')
+        resp_restart = self._get('session/restart')
         resp_list_1 = self._get('models')
         rj1 = resp_list_1.json()
         self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')
@@ -177,4 +177,9 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         )
         self.assertEqual(resp_change.status_code, 200)
         resp_check = self._get('paths')
-        self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])
\ No newline at end of file
+        self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])
+
+    def test_get_logs(self):
+        resp = self._get('session/logs')
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(resp.json()[0]['message'], 'Initialized session')
\ No newline at end of file
diff --git a/tests/test_session.py b/tests/test_session.py
index 9679aad6..aafda3c2 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,72 +1,83 @@
+import json
+from os.path import exists
 import pathlib
 import unittest
+
 from model_server.base.models import DummySemanticSegmentationModel
 from model_server.base.session import Session
+from model_server.base.workflows import WorkflowRunRecord
 
 class TestGetSessionObject(unittest.TestCase):
     def setUp(self) -> None:
-        pass
+        self.sesh = Session()
+
+    def tearDown(self) -> None:
+        print('Tearing down...')
+        Session._instances = {}
 
-    def test_single_session_instance(self):
-        sesh = Session()
-        self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
+    def test_session_is_singleton(self):
+        Session._instances = {}
+        self.assertEqual(len(Session._instances), 0)
+        s = Session()
+        self.assertEqual(len(Session._instances), 1)
+        self.assertIs(s, Session())
+        self.assertEqual(len(Session._instances), 1)
 
-        from os.path import exists
-        self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place')
-        self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
+    def test_session_logfile_is_valid(self):
+        self.assertTrue(exists(self.sesh.logfile), 'Session did not create a log file in the correct place')
 
     def test_changing_session_root_creates_new_directory(self):
         from model_server.conf.defaults import root
         from shutil import rmtree
 
-        sesh = Session()
-        old_paths = sesh.get_paths()
+        old_paths = self.sesh.get_paths()
         newroot = root / 'subdir'
-        sesh.restart(root=newroot)
-        new_paths = sesh.get_paths()
+        self.sesh.restart(root=newroot)
+        new_paths = self.sesh.get_paths()
         for k in old_paths.keys():
             self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__()))
+
+        # this is necessary because logger itself is a singleton class
+        self.tearDown()
+        self.setUp()
         rmtree(newroot)
         self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
 
     def test_change_session_subdirectory(self):
-        sesh = Session()
-        old_paths = sesh.get_paths()
+        old_paths = self.sesh.get_paths()
         print(old_paths)
-        sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
-        self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])
-
-
-    def test_restart_session(self):
-        sesh = Session()
-        logfile1 = sesh.session_log
-        sesh.restart()
-        logfile2 = sesh.session_log
+        self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
+        self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
+
+    def test_restarting_session_creates_new_logfile(self):
+        logfile1 = self.sesh.logfile
+        self.assertTrue(logfile1.exists())
+        self.sesh.restart()
+        logfile2 = self.sesh.logfile
+        self.assertTrue(logfile2.exists())
         self.assertNotEqual(logfile1, logfile2, 'Restarting session does not generate new logfile')
 
-    def test_session_records_workflow(self):
-        import json
-        from model_server.base.workflows import WorkflowRunRecord
-        sesh = Session()
-        di = WorkflowRunRecord(
-            model_id='test_model',
-            input_filepath='/test/input/directory',
-            output_filepath='/test/output/fi.le',
-            success=True,
-            timer_results={'start': 0.123},
-        )
-        sesh.record_workflow_run(di)
-        with open(sesh.manifest_json, 'r') as fh:
-            do = json.load(fh)
-        self.assertEqual(di.dict(), do, 'Manifest record is not correct')
-
+    def test_log_warning(self):
+        msg = 'A test warning'
+        self.sesh.log_info(msg)
+        with open(self.sesh.logfile, 'r') as fh:
+            log = fh.read()
+        self.assertTrue(msg in log)
+
+    def test_get_logs(self):
+        self.sesh.log_info('Info example 1')
+        self.sesh.log_warning('Example warning')
+        self.sesh.log_info('Info example 2')
+        logs = self.sesh.get_log_data()
+        self.assertEqual(len(logs), 4)
+        self.assertEqual(logs[1]['level'], 'WARNING')
+        self.assertEqual(logs[-1]['message'], 'Initialized session')
 
     def test_session_loads_model(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
-        success = sesh.load_model(MC)
+        success = self.sesh.load_model(MC)
         self.assertTrue(success)
-        loaded_models = sesh.describe_loaded_models()
+        loaded_models = self.sesh.describe_loaded_models()
         self.assertTrue(
             (MC.__name__ + '_00') in loaded_models.keys()
         )
@@ -76,46 +87,56 @@ class TestGetSessionObject(unittest.TestCase):
         )
 
     def test_session_loads_second_instance_of_same_model(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
-        sesh.load_model(MC)
-        sesh.load_model(MC)
-        self.assertIn(MC.__name__ + '_00', sesh.models.keys())
-        self.assertIn(MC.__name__ + '_01', sesh.models.keys())
-
+        self.sesh.load_model(MC)
+        self.sesh.load_model(MC)
+        self.assertIn(MC.__name__ + '_00', self.sesh.models.keys())
+        self.assertIn(MC.__name__ + '_01', self.sesh.models.keys())
 
     def test_session_loads_model_with_params(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
         p1 = {'p1': 'abc'}
-        success = sesh.load_model(MC, params=p1)
+        success = self.sesh.load_model(MC, params=p1)
         self.assertTrue(success)
-        loaded_models = sesh.describe_loaded_models()
+        loaded_models = self.sesh.describe_loaded_models()
         mid = MC.__name__ + '_00'
         self.assertEqual(loaded_models[mid]['params'], p1)
 
         # load a second model and confirm that the first is locatable by its param entry
         p2 = {'p2': 'def'}
-        sesh.load_model(MC, params=p2)
-        find_mid = sesh.find_param_in_loaded_models('p1', 'abc')
+        self.sesh.load_model(MC, params=p2)
+        find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
         self.assertEqual(mid, find_mid)
-        self.assertEqual(sesh.describe_loaded_models()[mid]['params'], p1)
+        self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
 
     def test_session_finds_existing_model_with_different_path_formats(self):
-        sesh = Session()
         MC = DummySemanticSegmentationModel
         p1 = {'path': 'c:\\windows\\dummy.pa'}
         p2 = {'path': 'c:/windows/dummy.pa'}
-        mid = sesh.load_model(MC, params=p1)
+        mid = self.sesh.load_model(MC, params=p1)
         assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
-        find_mid = sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
+        find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
         self.assertEqual(mid, find_mid)
 
     def test_change_output_path(self):
-        import pathlib
-        sesh = Session()
-        pa = sesh.get_paths()['inbound_images']
+        pa = self.sesh.get_paths()['inbound_images']
         self.assertIsInstance(pa, pathlib.Path)
-        sesh.set_data_directory('outbound_images', pa.__str__())
-        self.assertEqual(sesh.paths['inbound_images'], sesh.paths['outbound_images'])
-        self.assertIsInstance(sesh.paths['outbound_images'], pathlib.Path)
\ No newline at end of file
+        self.sesh.set_data_directory('outbound_images', pa.__str__())
+        self.assertEqual(self.sesh.paths['inbound_images'], self.sesh.paths['outbound_images'])
+        self.assertIsInstance(self.sesh.paths['outbound_images'], pathlib.Path)
+
+    def test_make_table(self):
+        import pandas as pd
+        data = [{'modulo': i % 2, 'times one hundred': i * 100} for i in range(0, 8)]
+        self.sesh.write_to_table(
+            'test_numbers', {'X': 0, 'Y': 0}, pd.DataFrame(data[0:4])
+        )
+        self.assertTrue(self.sesh.tables['test_numbers'].path.exists())
+        self.sesh.write_to_table(
+            'test_numbers', {'X': 1, 'Y': 1}, pd.DataFrame(data[4:8])
+        )
+
+        dfv = pd.read_csv(self.sesh.tables['test_numbers'].path)
+        self.assertEqual(len(dfv), len(data))
+        self.assertEqual(dfv.columns[0], 'X')
+        self.assertEqual(dfv.columns[1], 'Y')
-- 
GitLab