from multiprocessing import Process
import time
import unittest

import requests

from conf.testing import czifile
from model_server.model import DummyImageToImageModel


class TestServerBaseClass(unittest.TestCase):
    """Base test case that runs the ``api:app`` server in a child process.

    ``setUp`` starts a fresh uvicorn server on ``self.uri`` and waits until it
    answers requests; ``tearDown`` stops it again.
    """

    # Seconds to wait for the child server to start answering requests.
    server_startup_timeout = 10.0
    # Seconds to wait for the child to exit after terminate() before kill().
    server_shutdown_timeout = 5.0

    def setUp(self) -> None:
        import uvicorn

        host = '127.0.0.1'
        port = 5000
        self.server_process = Process(
            target=uvicorn.run,
            args=('api:app', ),
            kwargs={'host': host, 'port': port, 'log_level': 'debug'},
            daemon=True
        )
        self.uri = f'http://{host}:{port}/'
        self.server_process.start()
        self._wait_for_server()

    def _wait_for_server(self) -> None:
        """Block until the server answers on ``self.uri``, or fail the test.

        ``Process.start`` returns as soon as the child is spawned, before
        uvicorn has bound its socket; without this wait the first request of a
        test can race the server startup and raise ``ConnectionError``.
        """
        deadline = time.monotonic() + self.server_startup_timeout
        while True:
            try:
                requests.get(self.uri, timeout=1)
                return
            except (requests.exceptions.ConnectionError,
                    requests.exceptions.Timeout):
                pass
            # tearDown is not run when setUp fails, so stop the child here.
            if not self.server_process.is_alive():
                exitcode = self.server_process.exitcode
                self._stop_server()
                self.fail(
                    f'Server process exited with code {exitcode} '
                    f'before accepting requests'
                )
            if time.monotonic() > deadline:
                self._stop_server()
                self.fail(
                    f'Server did not respond on {self.uri} within '
                    f'{self.server_startup_timeout} s'
                )
            time.sleep(0.1)

    def _stop_server(self) -> None:
        """Terminate the server child process and wait for it to exit."""
        self.server_process.terminate()
        # Wait for the exit so the next test's server can bind the same port.
        self.server_process.join(timeout=self.server_shutdown_timeout)
        if self.server_process.is_alive():
            self.server_process.kill()
            self.server_process.join()

    @staticmethod
    def copy_input_file_to_server():
        """Copy the test CZI file into the server's inbound image directory."""
        import pathlib
        from shutil import copyfile

        from conf.server import paths

        # Wrap in Path *before* joining so this also works if the configured
        # directory is a plain string.
        outpath = pathlib.Path(paths['images']['inbound']) / czifile['filename']
        copyfile(czifile['path'], outpath)

    def tearDown(self) -> None:
        self._stop_server()


class TestApiFromAutomatedClient(TestServerBaseClass):
    """End-to-end tests that exercise the HTTP API with ``requests``."""

    def _load_dummy_model(self):
        """Load the dummy model through the API and return its model id.

        Also checks that the model then appears in the ``models`` listing with
        class ``DummyImageToImageModel``.
        """
        resp_load = requests.put(self.uri + 'models/dummy/load')
        # Check the status before reading the body so a failed load reports
        # the server's response instead of a KeyError / JSON decode error.
        self.assertEqual(resp_load.status_code, 200, resp_load.content.decode())
        model_id = resp_load.json()['model_id']

        resp_list = requests.get(self.uri + 'models')
        self.assertEqual(resp_list.status_code, 200, resp_list.content.decode())
        rj = resp_list.json()
        self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
        return model_id

    def test_trivial_api_response(self):
        resp = requests.get(self.uri)
        self.assertEqual(resp.status_code, 200)

    def test_bounceback_parameters(self):
        resp = requests.put(self.uri + 'bounce_back', params={'par1': 'hello'})
        self.assertEqual(resp.status_code, 200, resp.content.decode())
        self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json())
        self.assertEqual(resp.json()['params']['par2'], None, resp.json())

    def test_list_empty_loaded_models(self):
        resp = requests.get(self.uri + 'models')
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.content, b'{}')

    def test_load_dummy_model(self):
        self._load_dummy_model()

    def test_respond_with_error_when_invalid_filepath_requested(self):
        model_id = self._load_dummy_model()

        resp = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                'model_id': model_id,
                'input_filename': 'not_a_real_file.name'
            }
        )
        self.assertEqual(resp.status_code, 404, resp.content.decode())

    def test_i2i_inference_errors_when_model_not_found(self):
        model_id = 'not_a_real_model'
        resp = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                'model_id': model_id,
                'input_filename': 'not_a_real_file.name'
            }
        )
        self.assertEqual(resp.status_code, 409, resp.content.decode())

    def test_i2i_dummy_inference_by_api(self):
        model_id = self._load_dummy_model()
        self.copy_input_file_to_server()
        resp_infer = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                'model_id': model_id,
                'input_filename': czifile['filename'],
                'channel': 2,
            },
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())

    def test_restarting_session_clears_loaded_models(self):
        resp_load = requests.put(self.uri + 'models/dummy/load')
        self.assertEqual(resp_load.status_code, 200, resp_load.content.decode())

        resp_list_0 = requests.get(self.uri + 'models')
        self.assertEqual(resp_list_0.status_code, 200)
        rj0 = resp_list_0.json()
        self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')

        requests.get(self.uri + 'restart')

        resp_list_1 = requests.get(self.uri + 'models')
        self.assertEqual(resp_list_1.status_code, 200, resp_list_1.content.decode())
        rj1 = resp_list_1.json()
        self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')