Skip to content
Snippets Groups Projects
Commit d2ea41f9 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge in path-change API

parent 84a9187e
No related branches found
No related tags found
No related merge requests found
......@@ -6,4 +6,9 @@ subdirectories = {
'logs': 'logs',
'inbound_images': 'images/inbound',
'outbound_images': 'images/outbound',
}
server_conf = {
'host': '127.0.0.1',
'port': 8000,
}
\ No newline at end of file
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from model_server.models import DummyImageToImageModel
from model_server.session import Session
from model_server.session import Session, InvalidPathError
from model_server.validators import validate_workflow_inputs
from model_server.workflows import infer_image_to_image
from extensions.ilastik.workflows import infer_px_then_ob_model
......@@ -28,6 +28,24 @@ def list_bounce_back(par1=None, par2=None):
def list_session_paths():
return session.get_paths()
def change_path(key, path):
try:
session.set_data_directory(key, path)
except InvalidPathError as e:
raise HTTPException(
status_code=404,
detail=e.__str__(),
)
return session.get_paths()
@app.put('/paths/watch_input')
def watch_input_path(path: str):
return change_path('inbound_images', path)
@app.put('/paths/watch_output')
def watch_input_path(path: str):
return change_path('outbound_images', path)
@app.get('/restart')
def restart_session(root: str = None) -> dict:
session.restart(root=root)
......
......@@ -36,6 +36,13 @@ class Session(object):
def get_paths(self):
return self.paths
def set_data_directory(self, key: str, path: Path):
if not key in self.paths.keys():
raise InvalidPathError(f'No such path {key}')
if not Path(path).exists():
raise InvalidPathError(f'Could not find {path}')
self.paths[key] = path
@staticmethod
def make_paths(root: str = None) -> dict:
"""
......@@ -143,4 +150,7 @@ class CouldNotInstantiateModelError(Error):
pass
class CouldNotCreateDirectory(Error):
pass
class InvalidPathError(Error):
pass
\ No newline at end of file
......@@ -2,19 +2,21 @@ import argparse
from multiprocessing import Process
import uvicorn
from conf.defaults import server_conf
def parse_args():
parser = argparse.ArgumentParser(
description='Start model server with optional arguments',
)
parser.add_argument(
'--host',
default='127.0.0.1',
default=server_conf['host'],
help='bind socket to this host'
)
parser.add_argument(
'--port',
default='8000',
help='bind socket to this port, default=8000',
default=str(server_conf['port']),
help='bind socket to this port',
)
parser.add_argument(
'--debug',
......@@ -46,6 +48,7 @@ if __name__ == '__main__':
print('Running in debug mode')
print('Type "STOP" to stop server')
input_str = ''
while input_str.upper() != 'STOP': input_str = input()
while input_str.upper() != 'STOP':
input_str = input()
print('Finished')
......@@ -121,4 +121,38 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
resp_restart = requests.get(self.uri + 'restart')
resp_list_1 = requests.get(self.uri + 'models')
rj1 = resp_list_1.json()
self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')
\ No newline at end of file
self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')
def test_change_inbound_path(self):
resp_inpath = requests.get(
self.uri + 'paths'
)
resp_change = requests.put(
self.uri + f'paths/watch_output',
params={
'path': resp_inpath.json()['inbound_images']
}
)
self.assertEqual(resp_change.status_code, 200)
resp_check = requests.get(
self.uri + 'paths'
)
self.assertEqual(resp_check.json()['inbound_images'], resp_check.json()['outbound_images'])
def test_exception_when_changing_inbound_path(self):
resp_inpath = requests.get(
self.uri + 'paths'
)
fakepath = 'c:/fake/path/to/nowhere'
resp_change = requests.put(
self.uri + f'paths/watch_output',
params={
'path': fakepath,
}
)
self.assertEqual(resp_change.status_code, 404)
self.assertIn(fakepath, resp_change.json()['detail'])
resp_check = requests.get(
self.uri + 'paths'
)
self.assertEqual(resp_inpath.json()['inbound_images'], resp_check.json()['inbound_images'])
\ No newline at end of file
......@@ -28,6 +28,13 @@ class TestGetSessionObject(unittest.TestCase):
rmtree(newroot)
self.assertFalse(newroot.exists(), 'Could not clean up temporary test subdirectory')
def test_change_session_subdirectory(self):
sesh = Session()
old_paths = sesh.get_paths()
print(old_paths)
sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
self.assertEqual(sesh.paths['outbound_images'], sesh.paths['inbound_images'])
def test_restart_session(self):
sesh = Session()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment