diff --git a/api.py b/api.py index d130b0632cdd30733b548b9390d167d80ba635a2..e62758b072db59306aeb149e4392343b0cca1bdd 100644 --- a/api.py +++ b/api.py @@ -22,6 +22,11 @@ def read_root(): def read_root(par1=None, par2=None): return {'success': True, 'params': {'par1': par1, 'par2': par2}} +@app.get('/restart') +def restart_session() -> dict: + session.restart() + return session.describe_loaded_models() + @app.get('/models') def list_active_models(): return session.describe_loaded_models() diff --git a/imagej/infer_ilastik_by_api.py b/imagej/infer_ilastik_by_api.py index f974f8a517e88e743226f60a9c5fb7f16a432d07..c4f0dda32b2929495835a93f64ec57e18f009c2e 100644 --- a/imagej/infer_ilastik_by_api.py +++ b/imagej/infer_ilastik_by_api.py @@ -1,5 +1,4 @@ import httplib -import urllib2 import json import urllib @@ -18,33 +17,32 @@ input_filename = os.path.split(abspath)[-1] outpath = 'C:\\Users\\rhodes\\projects\\proj0015-model-server\\resources\\testdata' -global connection -connection = httplib.HTTPConnection('127.0.0.1', 8001) -def hit_endpoint(method, endpoint, params=None): - if not method in ['GET', 'PUT']: - raise Exception('Can only handle GET and PUT requests') - if params: - url = endpoint + '?' + urllib.urlencode(params) - else: - url = endpoint - connection.request(method, url) - resp = connection.getresponse() - print(method + ' ' + url + ', status ' + str(resp.status) + ':\n' + str(json.loads(resp.read()))) - return resp - -hit_endpoint('GET', '/') -hit_endpoint('GET', '/models') -hit_endpoint('PUT', '/bounce_back', {'par1': 'ghij'}) -hit_endpoint('PUT', '/models/ilastik/pixel_classification/load/', {'project_file': 'demo_px.ilp'}) -resp = hit_endpoint('GET', '/models') - -print(resp.read()) -#print(json.loads(resp.read())) # trying to extract model_id, but json.loads throws an error - -#infer_params = { -# 'model_id': model_id, -# 'input_filename': input_filename, -# 'channel': 0 -# } -# -#hit_endpoint('PUT', '/infer/from_image_file', infer_params) \ No newline at end of file +def hit_endpoint(method, endpoint, params=None, verbose=False): + connection = httplib.HTTPConnection(host, port) + if not method in ['GET', 'PUT']: + raise Exception('Can only handle GET and PUT requests') + if params: + url = endpoint + '?' + urllib.urlencode(params) + else: + url = endpoint + connection.request(method, url) + resp = connection.getresponse() + resp_str = resp.read() + if verbose: + print(method + ' ' + url + ', status ' + str(resp.status) + ':\n' + resp_str) + return json.loads(resp_str) + +#hit_endpoint('GET', '/') +#hit_endpoint('GET', '/models') +#hit_endpoint('PUT', '/bounce_back', {'par1': 'ghij'}) +resp = hit_endpoint('PUT', '/models/ilastik/pixel_classification/load/', {'project_file': 'demo_px.ilp'}) +pxmid = resp['model_id'] +resp = hit_endpoint('GET', '/models', verbose=True) + +infer_params = { + 'model_id': pxmid, + 'input_filename': input_filename, + 'channel': 0 + } + +hit_endpoint('PUT', '/infer/from_image_file', infer_params) \ No newline at end of file diff --git a/run_server.py b/run_server.py index 1e90251c6650cac3be7b2383cf32fa61a4bdb41f..c76d91ff93aec8b881b89a43856e3356ef414432 100644 --- a/run_server.py +++ b/run_server.py @@ -4,4 +4,4 @@ host = '127.0.0.1' port = 8001 if __name__ == '__main__': - uvicorn.run('api:app', **{'host': host, 'port': port, 'log_level': 'debug'}, reload=True) \ No newline at end of file + uvicorn.run('api:app', **{'host': host, 'port': port, 'log_level': 'debug'}, reload=False) \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index ae2529f8455d741c8e544d4665d4db0da5fc7640..b7835ffd2e5b7f4e24834d690c6223d9f21b1791 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -64,7 +64,6 @@ class TestApiFromAutomatedClient(TestServerBaseClass): self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') return model_id - def test_respond_with_error_when_invalid_filepath_requested(self): model_id = self.test_load_dummy_model() @@ -102,3 +101,16 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + def test_restarting_session_clears_loaded_models(self): + resp_load = requests.put( + self.uri + f'models/dummy/load', + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + resp_list_0 = requests.get(self.uri + 'models') + self.assertEqual(resp_list_0.status_code, 200) + rj0 = resp_list_0.json() + self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}') + resp_restart = requests.get(self.uri + 'restart') + resp_list_1 = requests.get(self.uri + 'models') + rj1 = resp_list_1.json() + self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}') \ No newline at end of file