class CsvTable(object):
    """
    CSV file on disk that accumulates rows over successive calls to append().

    The first append() made through an instance writes the file in 'w' mode with a header row;
    every later append() adds rows in 'a' mode without a header.
    """

    def __init__(self, fpath: Path):
        """
        :param fpath: path of the CSV file that append() writes to
        """
        self.path = fpath
        # True until the first successful write; selects write-with-header vs. append-without-header
        self.empty = True

    def append(self, coords: dict, data: pd.DataFrame) -> bool:
        """
        Write the rows of data to the CSV file, preceded by one column per coordinate.

        :param coords: dictionary of coordinates; each key becomes a leading column (in dictionary order)
            and its value is assigned to that column
        :param data: DataFrame containing the rows to write; it is left unmodified
        :return: True if successful
        :raises TypeError: if data is not a DataFrame
        :raises ValueError: if a coordinate name is already a column of data (raised by DataFrame.insert)
        """
        # explicit check rather than assert, which is stripped when Python runs with -O
        if not isinstance(data, pd.DataFrame):
            raise TypeError(f'Expected a pandas DataFrame, got {type(data).__name__}')

        # insert into a copy so the caller's frame is untouched and can be passed in again
        rows = data.copy()
        for position, (key, value) in enumerate(coords.items()):
            rows.insert(position, key, value)

        first_write = self.empty
        rows.to_csv(
            self.path,
            index=False,
            mode='w' if first_write else 'a',
            header=first_write,
        )
        # only reached when the write did not raise
        self.empty = False
        return True
def write_to_table(self, name: str, coords: dict, data: pd.DataFrame):
    """
    Write data to a named data table, initializing if it does not yet exist.

    :param name: name of the table to persist through session
    :param coords: dictionary of coordinates to associate with all rows in this method call
    :param data: DataFrame containing data
    :return: True if successful
    :raises CouldNotCreateTable: if the table could not be looked up or created
    :raises CouldNotAppendToTable: if appending the data to the table failed
    """
    try:
        table = self.tables.get(name)
        if table is None:
            # first write under this name: file lives in the session's tables directory
            table = CsvTable(self.paths['tables'] / (name + '.csv'))
            self.tables[name] = table
    except Exception as e:
        # chain the cause so the underlying failure stays visible in tracebacks
        raise CouldNotCreateTable(f'Unable to create table named {name}') from e

    try:
        table.append(coords, data)
    except Exception as e:
        raise CouldNotAppendToTable(f'Unable to append data to table named {name}') from e
    return True
def test_make_table(self):
    """
    Write two batches of rows to the same named table and read the CSV back.

    Checks that the file exists after the first write, that the row count equals the
    total number of records written, and that the first two columns are 'X' and 'Y'.
    """
    import pandas as pd

    table_name = 'test_numbers'
    records = []
    for i in range(0, 8):
        records.append({'modulo': i % 2, 'times one hundred': i * 100})

    # first batch: creates the table
    first_batch = pd.DataFrame(records[0:4])
    self.sesh.write_to_table(table_name, {'X': 0, 'Y': 0}, first_batch)
    csv_path = self.sesh.tables[table_name].path
    self.assertTrue(csv_path.exists())

    # second batch: goes to the same table under different coordinates
    second_batch = pd.DataFrame(records[4:8])
    self.sesh.write_to_table(table_name, {'X': 1, 'Y': 1}, second_batch)

    written = pd.read_csv(self.sesh.tables[table_name].path)
    self.assertEqual(len(written), len(records))
    leading_column, next_column = written.columns[0], written.columns[1]
    self.assertEqual(leading_column, 'X')
    self.assertEqual(next_column, 'Y')