diff --git a/api.py b/api.py index ff82c0448cdbeb7368f10a34e129afde7cbd1450..ace225b4df645a8bde6eb65d8acafe0adfcbc604 100644 --- a/api.py +++ b/api.py @@ -19,20 +19,16 @@ def read_root(): return {'success': True} @app.put('/bounce_back') -def read_root(par1=None, par2=None): +def list_bounce_back(par1=None, par2=None): return {'success': True, 'params': {'par1': par1, 'par2': par2}} @app.get('/paths') def list_session_paths(): - pass - -@app.put('/paths/{path_id}') -def set_path_id(path_id: str, abs_path: str): - pass + return session.get_paths() @app.get('/restart') -def restart_session() -> dict: - session.restart() +def restart_session(root: str = None) -> dict: + session.restart(root=root) return session.describe_loaded_models() @app.get('/models') @@ -68,7 +64,7 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: status_code=409, detail=f'Model {model_id} has not been loaded' ) - inpath = session.inbound.path / input_filename + inpath = session.paths['inbound_images'] / input_filename if not inpath.exists(): raise HTTPException( status_code=404, @@ -77,7 +73,7 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: record = infer_image_to_image( inpath, session.models[model_id]['object'], - session.outbound.path, + session.paths['outbound_images'], channel=channel, ) session.record_workflow_run(record) diff --git a/conf/defaults.py b/conf/defaults.py index c1761b7a6712bcb6a8213c4672348401c8cf8cc8..39e07c4ec28fcdb269cda0b1378a02e4544b720c 100644 --- a/conf/defaults.py +++ b/conf/defaults.py @@ -1,14 +1,16 @@ from pathlib import Path root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources') -paths = { - 'logs': root / 'logs' / 'session', - 'inbound_images': root / 'images' / 'inbound', - 'outbound_images': root / 'images' / 'outbound', - 'ilastik_projects': root / 'ilastik', -} +# paths = { +# 'logs': root / 'logs' / 'session', +# 'inbound_images': root / 'images' / 'inbound', +# 'outbound_images': root / 'images' / 'outbound', +# 'ilastik_projects': root / 'ilastik', +# } -for gk in paths.keys(): - paths[gk].mkdir(parents=True, exist_ok=True) - -# TODO: consider configuring paths via API, and this just becomes a HTTP client script \ No newline at end of file +subdirectories = { + 'logs': 'logs/session', + 'inbound_images': 'images/inbound', + 'outbound_images': 'images/outbound', + 'ilastik_projects': 'ilastik', +} \ No newline at end of file diff --git a/model_server/session.py b/model_server/session.py index 989fc1b3f0d1297ebfcd9a4ffdccf802f5050d79..f52257434faae2ed06499595c56d672ead2f7f07 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -5,7 +5,7 @@ from pathlib import Path from time import strftime, localtime from typing import Dict -from conf.defaults import paths +import conf.defaults from model_server.model import Model from model_server.share import SharedImageDirectory from model_server.workflow import WorkflowRunRecord @@ -17,25 +17,51 @@ class Session(object): """ Singleton class for a server session that persists data between API calls """ - inbound = SharedImageDirectory(paths['inbound_images']) - outbound = SharedImageDirectory(paths['outbound_images']) - where_records = paths['logs'] + # inbound = SharedImageDirectory(conf.defaults.paths['inbound_images']) + # outbound = SharedImageDirectory(conf.defaults.paths['outbound_images']) + # where_records = conf.defaults.paths['logs'] def __new__(cls): if not hasattr(cls, 'instance'): cls.instance = super(Session, cls).__new__(cls) return cls.instance - def __init__(self): + def __init__(self, root: str = None): print('Initializing session') self.models = {} # model_id : model object self.manifest = [] # paths to data as well as other metadata from each inference run - self.session_id = self.create_session_id(self.where_records) - self.session_log = self.where_records / f'{self.session_id}.log' + self.paths = self.make_paths(root) + self.session_id = self.create_session_id(self.paths['logs']) + self.session_log = self.paths['logs'] / f'{self.session_id}.log' self.log_event('Initialized session') - self.manifest_json = self.where_records / f'{self.session_id}-manifest.json' + self.manifest_json = self.paths['logs'] / f'{self.session_id}-manifest.json' open(self.manifest_json, 'w').close() # instantiate empty json file + def get_paths(self): + return self.paths + + @staticmethod + def make_paths(root: str = None) -> dict: + """ + Set paths where images, logs, etc. are located in this session + :param root: absolute path to top-level directory + :return: dictionary of session paths + """ + paths = {} + if root is None: + root_path = Path(conf.defaults.root) + else: + root_path = Path(root) + sid = Session.create_session_id(root_path) + for pk in ['inbound_images', 'outbound_images', 'logs']: + pa = root_path / sid / conf.defaults.subdirectories[pk] + paths[pk] = pa + try: + pa.mkdir(parents=True, exist_ok=True) + except Exception: + raise CouldNotCreateDirectory(f'Could not create directory: {pa}') + return paths + @staticmethod def create_session_id(look_where: Path) -> str: """ @@ -98,8 +124,8 @@ class Session(object): for k in self.models.keys() } - def restart(self): - self.__init__() + def restart(self, **kwargs): + self.__init__(**kwargs) class Error(Exception): pass @@ -108,4 +134,7 @@ class InferenceRecordError(Error): pass class CouldNotInstantiateModelError(Error): + pass + +class CouldNotCreateDirectory(Error): pass \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index 8c6c248bdecb847943b668ca755cab57b223bdc7..d09108e6ea9a65d960440c7bde8ec8e30574f69f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,5 @@ from multiprocessing import Process +from pathlib import Path import requests import unittest @@ -20,18 +21,6 @@ class TestServerBaseClass(unittest.TestCase): self.uri = f'http://{host}:{port}/' self.server_process.start() - @staticmethod - def copy_input_file_to_server(): - from shutil import copyfile - from conf.defaults import paths - - outpath = paths['inbound_images'] / czifile['filename'] - - copyfile( - czifile['path'], - outpath - ) - def tearDown(self) -> None: self.server_process.terminate() @@ -49,8 +38,15 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_default_session_paths(self): import conf.defaults resp = requests.get(self.uri + 'paths') - for p in ['inbound', 'outbound', 'logs']: - self.assertEqual(resp.json()['paths'][p], conf.defaults.paths[p]) + conf_root = conf.defaults.root + for p in ['inbound_images', 'outbound_images', 'logs']: + self.assertTrue(resp.json()[p].startswith(conf_root.__str__())) + suffix = Path(conf.defaults.subdirectories[p]).__str__() + self.assertTrue(resp.json()[p].endswith(suffix)) + + def test_restart_with_new_root_directory(self): + pass + # TODO: implement def test_list_empty_loaded_models(self): resp = requests.get(self.uri + 'models') @@ -93,6 +89,17 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ) self.assertEqual(resp.status_code, 409, resp.content.decode()) + def copy_input_file_to_server(self): + from shutil import copyfile + + resp = requests.get(self.uri + 'paths') + pa = resp.json()['inbound_images'] + outpath = Path(pa) / czifile['filename'] + copyfile( + czifile['path'], + outpath + ) + def test_i2i_dummy_inference_by_api(self): model_id = self.test_load_dummy_model() self.copy_input_file_to_server() diff --git a/tests/test_session.py b/tests/test_session.py index 6a15c68be361d5e9eab2a668c5b186d359dc1047..9846dcfe83aea5b7a128da71f30db66d4f0c254e 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -14,6 +14,17 @@ class TestGetSessionObject(unittest.TestCase): self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place') self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place') + def test_can_change_session_root(self): + from conf.defaults import root + sesh = Session() + old_paths = sesh.get_paths() + newroot = root / 'subdir' + sesh.restart(root=newroot) + new_paths = sesh.get_paths() + + for k in old_paths.keys(): + self.assertTrue(new_paths[k].__str__().startswith(newroot.__str__())) + def test_restart_session(self): sesh = Session() logfile1 = sesh.session_log