diff --git a/conf/defaults.py b/conf/defaults.py index 9323b984322cbd82538408bd01cde34e606bfe7a..5d22087c21a7076152f31b1e0e0eb21f08a05504 100644 --- a/conf/defaults.py +++ b/conf/defaults.py @@ -6,4 +6,9 @@ subdirectories = { 'logs': 'logs', 'inbound_images': 'images/inbound', 'outbound_images': 'images/outbound', +} + +server_conf = { + 'host': '127.0.0.1', + 'port': 8000, } \ No newline at end of file diff --git a/model_server/api.py b/model_server/api.py index fdec6903976aa79ace5630cb9ffcbdf9848ae341..6e7fb5c38e0ca4f48ef237ce8cd91aa18df20228 100644 --- a/model_server/api.py +++ b/model_server/api.py @@ -1,7 +1,7 @@ -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from model_server.models import DummyImageToImageModel -from model_server.session import Session +from model_server.session import Session, InvalidPathError from model_server.validators import validate_workflow_inputs from model_server.workflows import infer_image_to_image from extensions.ilastik.workflows import infer_px_then_ob_model @@ -28,6 +28,24 @@ def list_bounce_back(par1=None, par2=None): def list_session_paths(): return session.get_paths() +def change_path(key, path): + try: + session.set_data_directory(key, path) + except InvalidPathError as e: + raise HTTPException( + status_code=404, + detail=e.__str__(), + ) + return session.get_paths() + +@app.put('/paths/watch_input') +def watch_input_path(path: str): + return change_path('inbound_images', path) + +@app.put('/paths/watch_output') +def watch_input_path(path: str): + return change_path('outbound_images', path) + @app.get('/restart') def restart_session(root: str = None) -> dict: session.restart(root=root) diff --git a/model_server/session.py b/model_server/session.py index 3d5db0bc16cb37fbb01f505ae143282294779e03..15ee33f4479107a45edcfacedf1cc5c63fcbe123 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -36,6 +36,13 @@ class Session(object): def get_paths(self): return self.paths + def set_data_directory(self, key: str, path: Path): + if not key in self.paths.keys(): + raise InvalidPathError(f'No such path {key}') + if not Path(path).exists(): + raise InvalidPathError(f'Could not find {path}') + self.paths[key] = path + @staticmethod def make_paths(root: str = None) -> dict: """ @@ -143,4 +150,7 @@ class CouldNotInstantiateModelError(Error): pass class CouldNotCreateDirectory(Error): + pass + +class InvalidPathError(Error): pass \ No newline at end of file diff --git a/scripts/run_server.py b/scripts/run_server.py index 52f4f0612e088be0484c357dbadafc1519628e5a..70b73880a251b614f0bed0c51dcdb1ef0ef6cf08 100644 --- a/scripts/run_server.py +++ b/scripts/run_server.py @@ -2,19 +2,21 @@ import argparse from multiprocessing import Process import uvicorn +from conf.defaults import server_conf + def parse_args(): parser = argparse.ArgumentParser( description='Start model server with optional arguments', ) parser.add_argument( '--host', - default='127.0.0.1', + default=server_conf['host'], help='bind socket to this host' ) parser.add_argument( '--port', - default='8000', - help='bind socket to this port, default=8000', + default=str(server_conf['port']), + help='bind socket to this port', ) parser.add_argument( '--debug', @@ -46,6 +48,7 @@ if __name__ == '__main__': print('Running in debug mode') print('Type "STOP" to stop server') input_str = '' - while input_str.upper() != 'STOP': input_str = input() + while input_str.upper() != 'STOP': + input_str = input() print('Finished') diff --git a/tests/test_api.py b/tests/test_api.py index e8e110563306633db6d3560a6bde2b0c389609bb..f1431ab54db067ecfb9114283a8bb81500020068 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -121,4 +121,38 @@ class TestApiFromAutomatedClient(TestServerBaseClass): resp_restart = requests.get(self.uri + 'restart') resp_list_1 = requests.get(self.uri + 'models') rj1 = resp_list_1.json() - self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}') \ No newline at end of file + self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}') + + def test_change_inbound_path(self): + resp_inpath = requests.get( + self.uri + 'paths' + ) + resp_change = requests.put( + self.uri + f'paths/watch_output', + params={ + 'path': resp_inpath.json()['inbound_images'] + } + ) + self.assertEqual(resp_change.status_code, 200) + resp_check = requests.get( + self.uri + 'paths' + ) + self.assertEqual(resp_check.json()['inbound_images'], resp_check.json()['outbound_images']) + + def test_exception_when_changing_inbound_path(self): + resp_inpath = requests.get( + self.uri + 'paths' + ) + fakepath = 'c:/fake/path/to/nowhere' + resp_change = requests.put( + self.uri + f'paths/watch_output', + params={ + 'path': fakepath, + } + ) + self.assertEqual(resp_change.status_code, 404) + self.assertIn(fakepath, resp_change.json()['detail']) + resp_check = requests.get( + self.uri + 'paths' + ) + self.assertEqual(resp_inpath.json()['inbound_images'], resp_check.json()['inbound_images']) \ No newline at end of file diff --git a/tests/test_session.py b/tests/test_session.py index fd4b79c94c64b6bd11fa679244d766a6671a3ceb..dee15d0e447c43b64a8ae9c556dfc4de28e3f613 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -28,6 +28,13 @@ class TestGetSessionObject(unittest.TestCase): rmtree(newroot) self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory') + def test_change_session_subdirectory(self): + sesh = Session() + old_paths = sesh.get_paths() + print(old_paths) + sesh.set_data_directory('outbound_images', old_paths['inbound_images']) + self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images']) + def test_restart_session(self): sesh = Session()