Skip to content
Snippets Groups Projects
test_api.py 6.88 KiB
Newer Older
from multiprocessing import Process
import requests
import unittest
from model_server.conf.testing import czifile

class TestServerBaseClass(unittest.TestCase):
    """Base test case that runs a uvicorn server in a child process.

    Subclasses get ``self.uri`` (the server's base URL) plus ``_get`` /
    ``_put`` helpers that send HTTP requests through a ``requests`` session
    configured with connection retries.

    NOTE(review): this listing appears to be missing lines (it looks scraped
    from a web view). ``port``, ``Retry``, ``json``, ``Path`` and ``resp``
    are referenced but never defined or imported in the visible source, and
    ``_get_sesh`` is called but has no ``def``. TODO confirm against the
    repository copy before relying on this file.
    """
    def setUp(self) -> None:
        """Start the API server in a daemon child process before each test.

        NOTE(review): ``port`` is not assigned anywhere in the visible source,
        and no positional app argument is passed to ``uvicorn.run`` here --
        presumably lines were lost. TODO confirm.
        """
        import uvicorn
        host = '127.0.0.1'

        self.server_process = Process(
            target=uvicorn.run,
            kwargs={'host': host, 'port': port, 'log_level': 'debug'},
            # Daemonic, so the parent process tries to terminate it on exit.
            daemon=True
        )
        self.uri = f'http://{host}:{port}/'
        self.server_process.start()

        # NOTE(review): the lines below read like the body of the missing
        # ``_get_sesh`` helper (``_get``/``_put`` call ``self._get_sesh()``).
        # As written, they run inside setUp and ``return sesh`` from it.
        # ``Retry`` is not imported in the visible source. TODO confirm.
        sesh = requests.Session()
        # Retry up to 5 times with a short backoff, presumably so requests
        # tolerate the server process still starting up.
        retries = Retry(
            total=5,
            backoff_factor=0.1,
        )
        sesh.mount('http://', requests.adapters.HTTPAdapter(max_retries=retries))
        return sesh

    def _get(self, endpoint):
        """Send a GET to ``self.uri + endpoint`` and return the response."""
        return self._get_sesh().get(self.uri + endpoint)

    def _put(self, endpoint, query=None, body=None):
        """Send a PUT to ``self.uri + endpoint`` and return the response.

        ``query`` is sent as URL query parameters. ``body`` is serialized
        with ``json.dumps`` and sent as the request body, so the default
        ``None`` is sent as the literal ``null``.

        NOTE(review): ``json`` is not imported in the visible source.
        """
        return self._get_sesh().put(
            self.uri + endpoint,
            params=query,
            data=json.dumps(body)
        )
    def copy_input_file_to_server(self):
        """Copy the test CZI file into the server's inbound-images directory.

        The destination directory is read from the ``inbound_images`` field of
        a JSON response. The file keeps the name ``czifile['filename']``.

        NOTE(review): ``resp`` is never assigned in this method; presumably
        the request that fetches the session paths was lost. ``Path`` is also
        not imported in the visible source. TODO confirm.
        """
        from shutil import copyfile

        pa = resp.json()['inbound_images']
        outpath = Path(pa) / czifile['filename']
        copyfile(
            czifile['path'],
            outpath
        )

    def tearDown(self) -> None:
        """Terminate the server process started in setUp."""
        self.server_process.terminate()
class TestApiFromAutomatedClient(TestServerBaseClass):
    """HTTP-level tests of the model server API, run against a live server.

    NOTE(review): many tests reference response variables (``resp``, ``rj``,
    ``model_id``, ``resp_infer``, ``resp_list_0``, ``rj1``, ``resp_inpath``,
    ``resp_change``, ``resp_check``) that are never assigned in the visible
    source. Several methods also contain dangling ``query=...`` / ``)``
    fragments. The listing appears to have lost lines (it looks scraped
    from a web view). TODO restore from the repository.
    """
    def test_trivial_api_response(self):
        """The server answers a request with HTTP 200.

        NOTE(review): ``resp`` is not assigned here; the request line appears
        to be missing.
        """
        self.assertEqual(resp.status_code, 200)

    def test_bounceback_parameters(self):
        """PUT to ``bounce_back`` echoes JSON body fields back under ``params``."""
        resp = self._put('bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']})
        # The response JSON is passed as the failure message to aid debugging.
        self.assertEqual(resp.status_code, 200, resp.json())
        self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json())
        self.assertEqual(resp.json()['params']['par2'], ['ab', 'cd'], resp.json())
    def test_default_session_paths(self):
        """Each session path starts with the configured root and ends with its
        default subdirectory from ``model_server.conf.defaults``.

        NOTE(review): ``resp`` and ``Path`` are not defined in the visible
        source; the request line appears to be missing.
        """
        import model_server.conf.defaults
        conf_root = model_server.conf.defaults.root
        for p in ['inbound_images', 'outbound_images', 'logs']:
            self.assertTrue(resp.json()[p].startswith(conf_root.__str__()))
            suffix = Path(model_server.conf.defaults.subdirectories[p]).__str__()
            self.assertTrue(resp.json()[p].endswith(suffix))

    def test_list_empty_loaded_models(self):
        """Before any model is loaded, the response body is the empty JSON object.

        NOTE(review): ``resp`` is not assigned; the request appears to be missing.
        """
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.content, b'{}')
    def test_load_dummy_semantic_model(self):
        """Loading ``dummy_semantic`` registers a DummySemanticSegmentationModel.

        NOTE(review): ``rj`` is not assigned. The model-list request done in
        test_load_dummy_instance_model appears to be missing here.
        """
        resp_load = self._put(f'models/dummy_semantic/load')
        model_id = resp_load.json()['model_id']
        self.assertEqual(resp_load.status_code, 200, resp_load.json())
        self.assertEqual(rj[model_id]['class'], 'DummySemanticSegmentationModel')
    def test_load_dummy_instance_model(self):
        """Load ``dummy_instance`` and check the ``models`` listing reports it.

        Returns the ``model_id`` reported by the load endpoint.

        NOTE(review): since Python 3.11, unittest emits a DeprecationWarning
        when a test method returns a non-None value. Consider moving the
        load-and-check steps into a helper.
        """
        resp_load = self._put(f'models/dummy_instance/load')
        model_id = resp_load.json()['model_id']
        self.assertEqual(resp_load.status_code, 200, resp_load.json())
        resp_list = self._get('models')
        self.assertEqual(resp_list.status_code, 200)
        rj = resp_list.json()
        # The listing is indexed by model_id; each entry carries a 'class' name.
        self.assertEqual(rj[model_id]['class'], 'DummyInstanceSegmentationModel')
        return model_id

    def test_respond_with_error_when_invalid_filepath_requested(self):
        """A request naming a nonexistent input file gets HTTP 404.

        NOTE(review): the call that opens the ``query=...`` argument is
        missing, and ``model_id`` / ``resp`` are not defined here.
        """
            query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'}
        )
        self.assertEqual(resp.status_code, 404, resp.content.decode())

    def test_i2i_inference_errors_when_model_not_found(self):
        """An i2i inference request for an unknown model gets HTTP 409.

        NOTE(review): the request call is missing, and ``model_id`` / ``resp``
        are not defined in the visible source.
        """
           query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'}
        self.assertEqual(resp.status_code, 409, resp.content.decode())

    def test_i2i_dummy_inference_by_api(self):
        """After the test image is copied to the server, i2i inference returns 200.

        NOTE(review): the model-load and inference requests appear to be
        missing; ``resp_infer`` is not assigned.
        """
        self.copy_input_file_to_server()
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
    def test_restarting_session_clears_loaded_models(self):
        """Calling ``session/restart`` empties the loaded-model listing.

        NOTE(review): ``resp_list_0`` and ``rj1`` (the model listings before
        and after the restart) are never assigned in the visible source.
        """
        resp_load = self._put(f'models/dummy_semantic/load',)
        self.assertEqual(resp_load.status_code, 200, resp_load.json())
        self.assertEqual(resp_list_0.status_code, 200)
        rj0 = resp_list_0.json()
        # Exactly one model is expected before the restart.
        self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')
        resp_restart = self._get('session/restart')
        self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')

    def test_change_inbound_path(self):
        """A path-change request using the current ``inbound_images`` value
        succeeds. Afterwards the inbound and outbound paths compare equal.

        NOTE(review): the requests producing ``resp_inpath``, ``resp_change``
        and ``resp_check`` are missing from the visible source.
        """
            query={'path': resp_inpath.json()['inbound_images']}
        )
        self.assertEqual(resp_change.status_code, 200)
        self.assertEqual(resp_check.json()['inbound_images'], resp_check.json()['outbound_images'])

    def test_exception_when_changing_inbound_path(self):
        """A path change to a nonexistent directory returns 404 whose
        ``detail`` names that path, and ``outbound_images`` is unchanged.

        NOTE(review): the path-change call, and the requests producing
        ``resp_inpath`` / ``resp_change`` / ``resp_check``, are missing.
        """
        fakepath = 'c:/fake/path/to/nowhere'
        )
        self.assertEqual(resp_change.status_code, 404)
        self.assertIn(fakepath, resp_change.json()['detail'])
        self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])

    def test_no_change_inbound_path(self):
        """A path-change request using the current ``outbound_images`` value
        succeeds, and ``outbound_images`` is unchanged afterwards.

        NOTE(review): the requests producing ``resp_inpath``, ``resp_change``
        and ``resp_check`` are missing from the visible source.
        """
            query={'path': resp_inpath.json()['outbound_images']}
        )
        self.assertEqual(resp_change.status_code, 200)
        self.assertEqual(resp_inpath.json()['outbound_images'], resp_check.json()['outbound_images'])

    def test_get_logs(self):
        """``session/logs`` returns 200, and the first log record's
        ``message`` is 'Initialized session'."""
        resp = self._get('session/logs')
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.json()[0]['message'], 'Initialized session')