Skip to content
Snippets Groups Projects
test_api.py 4.02 KiB
Newer Older
from multiprocessing import Process
import requests
import unittest
from conf.testing import czifile
from model_server.model import DummyImageToImageModel
class TestServerBaseClass(unittest.TestCase):
    """Base test case that runs the API server in a child process per test.

    ``setUp`` launches ``uvicorn`` serving ``api:app`` on 127.0.0.1:5000 and
    stores the base URL in ``self.uri``; ``tearDown`` stops that process.
    """

    # Upper bounds, in seconds, on waiting for the server process to start
    # accepting connections and to exit after being terminated.
    server_startup_timeout = 30.0
    server_shutdown_timeout = 10.0

    def setUp(self) -> None:
        """Start the server process and block until it accepts connections."""
        import uvicorn

        host = '127.0.0.1'
        port = 5000

        self.server_process = Process(
            target=uvicorn.run,
            args=('api:app', ),
            kwargs={'host': host, 'port': port, 'log_level': 'debug'},
            daemon=True
        )
        self.uri = f'http://{host}:{port}/'
        self.server_process.start()

        # Process.start() returns before the server has bound its socket, so
        # a test that fires a request right away would race the startup.
        if not self._wait_for_server(host, port):
            # unittest skips tearDown() when setUp() raises; clean up here.
            self.tearDown()
            self.fail(
                f'Server did not start accepting connections at {self.uri}'
            )

    def _wait_for_server(self, host: str, port: int) -> bool:
        """Poll ``host:port`` until a TCP connection succeeds.

        Returns:
            True once a connection is accepted; False if the server process
            exits first or ``server_startup_timeout`` seconds elapse.
        """
        import socket
        import time

        deadline = time.monotonic() + self.server_startup_timeout
        while time.monotonic() < deadline:
            if not self.server_process.is_alive():
                # The child died (e.g. port already in use); stop waiting.
                return False
            try:
                with socket.create_connection((host, port), timeout=1.0):
                    return True
            except OSError:
                time.sleep(0.1)
        return False

    @staticmethod
    def copy_input_file_to_server():
        """Copy the test image into the server's inbound images directory.

        The source is ``czifile['path']``; the destination is
        ``paths['images']['inbound']`` joined with ``czifile['filename']``.
        """
        import pathlib
        from shutil import copyfile
        from conf.server import paths

        # Coerce the directory before joining so that a plain string in the
        # configuration works as well as a Path object.
        outpath = pathlib.Path(paths['images']['inbound']) / czifile['filename']

        copyfile(
            czifile['path'],
            outpath
        )

    def tearDown(self) -> None:
        """Stop the server process and wait until it has actually exited."""
        self.server_process.terminate()
        # terminate() only sends the signal; join so the port is released
        # before the next test starts a new server on the same address.
        self.server_process.join(self.server_shutdown_timeout)
        if self.server_process.is_alive():
            # Graceful shutdown did not finish in time; force it.
            self.server_process.kill()
            self.server_process.join()

class TestApiFromAutomatedClient(TestServerBaseClass):
    """Exercise the HTTP API the way a scripted client would.

    Every test sends requests to endpoints addressed relative to ``self.uri``.
    """

    def test_trivial_api_response(self):
        """The root endpoint answers with HTTP 200."""
        resp = requests.get(self.uri)
        self.assertEqual(resp.status_code, 200)

    def test_bounceback_parameters(self):
        """``bounce_back`` echoes ``par1`` and reports ``par2`` as null."""
        resp = requests.put(self.uri + 'bounce_back', params={'par1': 'hello'})
        self.assertEqual(resp.status_code, 200, resp.json())
        self.assertEqual(resp.json()['params']['par1'], 'hello', resp.json())
        self.assertEqual(resp.json()['params']['par2'], None, resp.json())

    def test_list_empty_loaded_models(self):
        """With nothing loaded, ``models`` returns an empty JSON object."""
        resp = requests.get(self.uri + 'models')
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.content, b'{}')

    def test_load_dummy_model(self):
        """Load the dummy model and confirm it appears in the model listing.

        Returns:
            The ``model_id`` reported by the load request. Other tests in
            this class call this method as a helper and use that value.
            NOTE(review): unittest on Python 3.11+ emits a
            DeprecationWarning for test methods returning a value; consider
            moving the shared logic into a non-test helper.
        """
        # NOTE(review): endpoint and verb mirror the load request made in
        # test_restarting_session_clears_loaded_models - confirm.
        resp_load = requests.put(self.uri + 'models/dummy/load')
        # Check the status first so a failed load reports the server's
        # message rather than a KeyError on the missing 'model_id'.
        self.assertEqual(resp_load.status_code, 200, resp_load.json())
        model_id = resp_load.json()['model_id']

        resp_list = requests.get(self.uri + 'models')
        self.assertEqual(resp_list.status_code, 200)
        rj = resp_list.json()
        self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
        return model_id

    def test_respond_with_error_when_invalid_filepath_requested(self):
        """Inference on a nonexistent input file yields HTTP 404."""
        model_id = self.test_load_dummy_model()

        resp = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                'model_id': model_id,
                'input_filename': 'not_a_real_file.name'
            }
        )
        self.assertEqual(resp.status_code, 404, resp.content.decode())

    def test_i2i_inference_errors_when_model_not_found(self):
        """Inference with an unknown model ID yields HTTP 409."""
        model_id = 'not_a_real_model'
        resp = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                'model_id': model_id,
                'input_filename': 'not_a_real_file.name'
            }
        )
        self.assertEqual(resp.status_code, 409, resp.content.decode())

    def test_i2i_dummy_inference_by_api(self):
        """Dummy-model inference on the copied test image yields HTTP 200."""
        model_id = self.test_load_dummy_model()
        self.copy_input_file_to_server()
        resp_infer = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                # NOTE(review): 'model_id' is passed as in the other
                # inference requests of this class - confirm.
                'model_id': model_id,
                'input_filename': czifile['filename'],
                'channel': 2,
            },
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())

    def test_restarting_session_clears_loaded_models(self):
        """After ``restart``, the model listing is empty again."""
        resp_load = requests.put(
            self.uri + 'models/dummy/load',
        )
        self.assertEqual(resp_load.status_code, 200, resp_load.json())

        resp_list_0 = requests.get(self.uri + 'models')
        self.assertEqual(resp_list_0.status_code, 200)
        rj0 = resp_list_0.json()
        self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')

        resp_restart = requests.get(self.uri + 'restart')
        # NOTE(review): assumes 'restart' answers 200 on success - confirm.
        self.assertEqual(
            resp_restart.status_code, 200, resp_restart.content.decode()
        )

        resp_list_1 = requests.get(self.uri + 'models')
        self.assertEqual(resp_list_1.status_code, 200)
        rj1 = resp_list_1.json()
        self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')