Skip to content
Snippets Groups Projects
Commit d09bdffb authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added endpoint to restart session, i.e. to clear loaded models

parent c712a665
No related branches found
No related tags found
No related merge requests found
...@@ -22,6 +22,11 @@ def read_root(): ...@@ -22,6 +22,11 @@ def read_root():
def read_root(par1=None, par2=None): def read_root(par1=None, par2=None):
return {'success': True, 'params': {'par1': par1, 'par2': par2}} return {'success': True, 'params': {'par1': par1, 'par2': par2}}
@app.get('/restart')
def restart_session() -> dict:
session.restart()
return session.describe_loaded_models()
@app.get('/models') @app.get('/models')
def list_active_models(): def list_active_models():
return session.describe_loaded_models() return session.describe_loaded_models()
......
import httplib import httplib
import urllib2
import json import json
import urllib import urllib
...@@ -18,33 +17,32 @@ input_filename = os.path.split(abspath)[-1] ...@@ -18,33 +17,32 @@ input_filename = os.path.split(abspath)[-1]
outpath = 'C:\\Users\\rhodes\\projects\\proj0015-model-server\\resources\\testdata' outpath = 'C:\\Users\\rhodes\\projects\\proj0015-model-server\\resources\\testdata'
global connection def hit_endpoint(method, endpoint, params=None, verbose=False):
connection = httplib.HTTPConnection('127.0.0.1', 8001) connection = httplib.HTTPConnection(host, port)
def hit_endpoint(method, endpoint, params=None): if not method in ['GET', 'PUT']:
if not method in ['GET', 'PUT']: raise Exception('Can only handle GET and PUT requests')
raise Exception('Can only handle GET and PUT requests') if params:
if params: url = endpoint + '?' + urllib.urlencode(params)
url = endpoint + '?' + urllib.urlencode(params) else:
else: url = endpoint
url = endpoint connection.request(method, url)
connection.request(method, url) resp = connection.getresponse()
resp = connection.getresponse() resp_str = resp.read()
print(method + ' ' + url + ', status ' + str(resp.status) + ':\n' + str(json.loads(resp.read()))) if verbose:
return resp print(method + ' ' + url + ', status ' + str(resp.status) + ':\n' + resp_str)
return json.loads(resp_str)
hit_endpoint('GET', '/')
hit_endpoint('GET', '/models') #hit_endpoint('GET', '/')
hit_endpoint('PUT', '/bounce_back', {'par1': 'ghij'}) #hit_endpoint('GET', '/models')
hit_endpoint('PUT', '/models/ilastik/pixel_classification/load/', {'project_file': 'demo_px.ilp'}) #hit_endpoint('PUT', '/bounce_back', {'par1': 'ghij'})
resp = hit_endpoint('GET', '/models') resp = hit_endpoint('PUT', '/models/ilastik/pixel_classification/load/', {'project_file': 'demo_px.ilp'})
pxmid = resp['model_id']
print(resp.read()) resp = hit_endpoint('GET', '/models', verbose=True)
#print(json.loads(resp.read())) # trying to extract model_id, but json.loads throws an error
infer_params = {
#infer_params = { 'model_id': pxmid,
# 'model_id': model_id, 'input_filename': input_filename,
# 'input_filename': input_filename, 'channel': 0
# 'channel': 0 }
# }
# hit_endpoint('PUT', '/infer/from_image_file', infer_params)
#hit_endpoint('PUT', '/infer/from_image_file', infer_params) \ No newline at end of file
\ No newline at end of file
...@@ -4,4 +4,4 @@ host = '127.0.0.1' ...@@ -4,4 +4,4 @@ host = '127.0.0.1'
port = 8001 port = 8001
if __name__ == '__main__': if __name__ == '__main__':
uvicorn.run('api:app', **{'host': host, 'port': port, 'log_level': 'debug'}, reload=True) uvicorn.run('api:app', **{'host': host, 'port': port, 'log_level': 'debug'}, reload=False)
\ No newline at end of file \ No newline at end of file
...@@ -64,7 +64,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ...@@ -64,7 +64,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
return model_id return model_id
def test_respond_with_error_when_invalid_filepath_requested(self): def test_respond_with_error_when_invalid_filepath_requested(self):
model_id = self.test_load_dummy_model() model_id = self.test_load_dummy_model()
...@@ -102,3 +101,16 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ...@@ -102,3 +101,16 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
) )
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
def test_restarting_session_clears_loaded_models(self):
resp_load = requests.put(
self.uri + f'models/dummy/load',
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list_0 = requests.get(self.uri + 'models')
self.assertEqual(resp_list_0.status_code, 200)
rj0 = resp_list_0.json()
self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')
resp_restart = requests.get(self.uri + 'restart')
resp_list_1 = requests.get(self.uri + 'models')
rj1 = resp_list_1.json()
self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment