# Christopher Randolph Rhodes authored
# session.py 6.43 KiB
import logging
import os
from pathlib import Path, PureWindowsPath
from pydantic import BaseModel
from time import strftime, localtime
from typing import Union
import pandas as pd
from ..conf import defaults
from .models import Model
logger = logging.getLogger(__name__)
class CsvTable(object):
    """
    Append-only CSV file on disk.

    The first call to append() writes the file with a header row (overwriting any
    existing file at that path); later calls add rows without repeating the header.
    """

    def __init__(self, fpath: Path):
        """
        :param fpath: path of the CSV file; nothing is written until the first append
        """
        self.path = fpath
        self.empty = True  # True until the first successful append

    def append(self, coords: dict, data: pd.DataFrame) -> bool:
        """
        Write rows to the CSV file, prefixing each row with the given coordinates.
        :param coords: dictionary of column name -> value; each becomes a leading column, in
            dictionary order, with the same value on every row of this call
        :param data: DataFrame containing the rows to write; it is not modified
        :return: True if successful
        :raises AssertionError: if data is not a DataFrame
        """
        # raised explicitly (rather than via `assert`) so the check survives `python -O`
        if not isinstance(data, pd.DataFrame):
            raise AssertionError('data must be a pandas DataFrame')

        # work on a copy: inserting into the caller's frame would leak the coordinate columns
        # back to the caller and make a second append of the same frame fail on duplicates
        frame = data.copy()

        # insert at position 0 in reverse order so the columns end up in coords order
        for c in reversed(list(coords)):
            frame.insert(0, c, coords[c])

        if self.empty:
            frame.to_csv(self.path, index=False, mode='w', header=True)
        else:
            frame.to_csv(self.path, index=False, mode='a', header=False)
        self.empty = False
        return True
class _Session(object):
"""
Singleton class for a server session that persists data between API calls
"""
log_format = '%(asctime)s - %(levelname)s - %(message)s'
def __init__(self, root: str = None):
self.models = {} # model_id : model object
self.paths = self.make_paths(root)
self.logfile = self.paths['logs'] / f'session.log'
logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format)
self.log_info('Initialized session')
self.tables = {}
def write_to_table(self, name: str, coords: dict, data: pd.DataFrame):
"""
Write data to a named data table, initializing if it does not yet exist.
:param name: name of the table to persist through session
:param coords: dictionary of coordinates to associate with all rows in this method call
:param data: DataFrame containing data
:return: True if successful
"""
try:
if name in self.tables.keys():
table = self.tables.get(name)
else:
table = CsvTable(self.paths['tables'] / (name + '.csv'))
self.tables[name] = table
except Exception:
raise CouldNotCreateTable(f'Unable to create table named {name}')
try:
table.append(coords, data)
return True
except Exception:
raise CouldNotAppendToTable(f'Unable to append data to table named {name}')
def get_paths(self):
return self.paths
def set_data_directory(self, key: str, path: str):
if not key in self.paths.keys():
raise InvalidPathError(f'No such path {key}')
if not Path(path).exists():
raise InvalidPathError(f'Could not find {path}')
self.paths[key] = Path(path)
@staticmethod
def make_paths(root: str = None) -> dict:
"""
Set paths where images, logs, etc. are located in this session
:param root: absolute path to top-level directory
:return: dictionary of session paths
"""
if root is None:
root_path = Path(defaults.root)
else:
root_path = Path(root)
sid = _Session.create_session_id(root_path)
paths = {'root': root_path}
for pk in ['inbound_images', 'outbound_images', 'logs', 'tables']:
pa = root_path / sid / defaults.subdirectories[pk]
paths[pk] = pa
try:
pa.mkdir(parents=True, exist_ok=True)
except Exception:
raise CouldNotCreateDirectory(f'Could not create directory: {pa}')
return paths
@staticmethod
def create_session_id(look_where: Path) -> str:
"""
Autogenerate a session ID by incrementing from a list of log files.
"""
yyyymmdd = strftime('%Y%m%d', localtime())
idx = 0
while os.path.exists(look_where / f'{yyyymmdd}-{idx:04d}'):
idx += 1
return f'{yyyymmdd}-{idx:04d}'
def get_log_data(self) -> list:
log = []
with open(self.logfile, 'r') as fh:
for line in fh:
k = ['datatime', 'level', 'message']
v = line.strip().split(' - ')[0:3]
log.insert(0, dict(zip(k, v)))
return log
def log_info(self, msg):
logger.info(msg)
def log_warning(self, msg):
logger.warning(msg)
def log_error(self, msg):
logger.error(msg)
# TODO: load, describe, unload data accessors
def load_model(self, ModelClass: Model, params: Union[BaseModel, None] = None) -> dict:
"""
Load an instance of a given model class and attach to this session's model registry
:param ModelClass: subclass of Model
:param params: optional parameters that are passed to the model's construct
:return: model_id of loaded model
"""
mi = ModelClass(params=params)
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
ii = 0
def mid(i):
return f'{ModelClass.__name__}_{i:02d}'
while mid(ii) in self.models.keys():
ii += 1
key = mid(ii)
self.models[key] = {
'object': mi,
'params': getattr(mi, 'params', None)
}
self.log_info(f'Loaded model {key}')
return key
def describe_loaded_models(self) -> dict:
return {
k: {
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
}
for k in self.models.keys()
}
def find_param_in_loaded_models(self, key: str, value: str, is_path=False) -> str:
"""
Returns model_id of first model where key and value match with .params field, or None
:param is_path: uses platform-independent path comparison if True
"""
models = self.describe_loaded_models()
for mid, det in models.items():
if is_path:
if PureWindowsPath(det.get('params').get(key)).as_posix() == Path(value).as_posix():
return mid
else:
if det.get('params').get(key) == value:
return mid
return None
def restart(self, **kwargs):
self.__init__(**kwargs)
# Create the module-level singleton instance. Constructing it runs _Session.__init__, which
# creates the session directories and configures logging, so this happens on import.
session = _Session()
class Error(Exception):
    """Base class for the exceptions defined in this module."""
class InferenceRecordError(Error):
    """Error concerning an inference record; not raised in the visible part of this module."""
class CouldNotInstantiateModelError(Error):
    """Error for a model that could not be instantiated; not raised in the visible part of this module."""
class CouldNotCreateDirectory(Error):
    """Raised by _Session.make_paths when a session directory cannot be created."""
class CouldNotCreateTable(Error):
    """Raised by _Session.write_to_table when a named table cannot be looked up or created."""
class CouldNotAppendToTable(Error):
    """Raised by _Session.write_to_table when appending data to a named table fails."""
class InvalidPathError(Error):
    """Raised by _Session.set_data_directory for an unknown path key or a path that does not exist."""