# test_api.py
from pathlib import Path

import model_server.conf.testing as conf
from model_server.conf.testing import TestServerBaseClass

# CZI test image entry from the testing configuration's meta table; it is
# assigned to the test class's input_data attribute.
czifile = conf.meta["image_files"]["czifile"]


class TestApiFromAutomatedClient(TestServerBaseClass):
    """End-to-end tests of the model server's HTTP API.

    Requests go through ``_get`` / ``_put`` (plus ``copy_input_file_to_server``
    and ``get_accessor``). This class does not define those helpers; they come
    from ``TestServerBaseClass``.
    """

    # Input image for tests that call copy_input_file_to_server().
    # NOTE(review): presumably consumed by TestServerBaseClass; confirm there.
    input_data = czifile

    def _load_and_check_dummy_model(self, model_name, expected_class):
        """Load a dummy model through the testing API and return its id.

        :param model_name: loader path segment, e.g. ``'dummy_semantic'``
        :param expected_class: class name the ``models`` listing must report
        :return: the ``model_id`` returned by the load request
        """
        resp_load = self._put(f'testing/models/{model_name}/load')
        # Check the status before indexing the JSON body, so a failed load
        # reports the server's response rather than a bare KeyError.
        self.assertEqual(resp_load.status_code, 200, resp_load.content.decode())
        model_id = resp_load.json()['model_id']
        resp_list = self._get('models')
        self.assertEqual(resp_list.status_code, 200)
        self.assertEqual(resp_list.json()[model_id]['class'], expected_class)
        return model_id

    def test_trivial_api_response(self):
        """The API root responds with HTTP 200."""
        resp = self._get('')
        self.assertEqual(resp.status_code, 200)

    def test_bounceback_parameters(self):
        """Parameters PUT to testing/bounce_back are echoed under 'params'."""
        resp = self._put('testing/bounce_back', body={'par1': 'hello', 'par2': ['ab', 'cd']})
        self.assertEqual(resp.status_code, 200, resp.content)
        payload = resp.json()
        self.assertEqual(payload['params']['par1'], 'hello', payload)
        self.assertEqual(payload['params']['par2'], ['ab', 'cd'], payload)

    def test_default_session_paths(self):
        """Session paths sit under the configured root and end with their default subdirectory."""
        import model_server.conf.defaults
        resp = self._get('paths')
        self.assertEqual(resp.status_code, 200, resp.content.decode())
        paths = resp.json()
        conf_root = str(model_server.conf.defaults.root)
        for p in ['inbound_images', 'outbound_images', 'logs']:
            self.assertTrue(paths[p].startswith(conf_root), paths[p])
            suffix = str(Path(model_server.conf.defaults.subdirectories[p]))
            self.assertTrue(paths[p].endswith(suffix), paths[p])

    def test_list_empty_loaded_models(self):
        """The models listing is empty when no model has been loaded."""
        resp = self._get('models')
        self.assertEqual(resp.status_code, 200)
        # Compare parsed JSON rather than raw bytes, so the check does not
        # depend on the serializer's whitespace.
        self.assertEqual(resp.json(), {})

    def test_load_dummy_semantic_model(self):
        """Loading dummy_semantic registers a DummySemanticSegmentationModel."""
        # Still returns the id, as before, for callers that reuse this test as
        # a setup step. Within this class, use _load_and_check_dummy_model.
        return self._load_and_check_dummy_model(
            'dummy_semantic', 'DummySemanticSegmentationModel'
        )

    def test_load_dummy_instance_model(self):
        """Loading dummy_instance registers a DummyInstanceSegmentationModel."""
        # Return value kept for backward compatibility (see semantic variant).
        return self._load_and_check_dummy_model(
            'dummy_instance', 'DummyInstanceSegmentationModel'
        )

    def test_respond_with_error_when_invalid_filepath_requested(self):
        """Inference on a nonexistent input file responds with HTTP 404."""
        model_id = self._load_and_check_dummy_model(
            'dummy_semantic', 'DummySemanticSegmentationModel'
        )
        resp = self._put(
            'infer/from_image_file',
            query={'model_id': model_id, 'input_filename': 'not_a_real_file.name'},
        )
        self.assertEqual(resp.status_code, 404, resp.content.decode())

    def test_pipeline_errors_when_ids_not_found(self):
        """The segment pipeline responds with 409 for an unknown accessor_id or model_id."""
        fname = self.copy_input_file_to_server()
        model_id = self._put('testing/models/dummy_semantic/load').json()['model_id']
        in_acc_id = self._put(f'accessors/read_from_file/{fname}').json()

        # respond with 409 for invalid accessor_id
        self.assertEqual(
            self._put(
                'pipelines/segment',
                body={'model_id': model_id, 'accessor_id': 'fake'},
            ).status_code,
            409,
        )

        # respond with 409 for invalid model_id
        self.assertEqual(
            self._put(
                'pipelines/segment',
                body={'model_id': 'fake', 'accessor_id': in_acc_id},
            ).status_code,
            409,
        )

    def test_i2i_dummy_inference_by_api(self):
        """Segmenting a loaded accessor yields a one-channel output plus two '_step' intermediates."""
        fname = self.copy_input_file_to_server()
        model_id = self._put('testing/models/dummy_semantic/load').json()['model_id']
        in_acc_id = self._put(f'accessors/read_from_file/{fname}').json()

        # run segmentation pipeline on preloaded accessor
        resp_infer = self._put(
            'pipelines/segment',
            body={
                'accessor_id': in_acc_id,
                'model_id': model_id,
                'channel': 2,
                'keep_interm': True,
            },
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
        out_acc_id = resp_infer.json()['output_accessor_id']
        self.assertTrue(self._get(f'accessors/{out_acc_id}').json()['loaded'])
        acc_out = self.get_accessor(out_acc_id, 'dummy_semantic_output.tif')
        self.assertEqual(acc_out.shape_dict['C'], 1)

        # validate intermediate data (kept because keep_interm=True)
        resp_list = self._get('accessors').json()
        self.assertEqual(len([k for k in resp_list if '_step' in k]), 2)

    def test_restarting_session_clears_loaded_models(self):
        """Restarting the session unloads all models."""
        resp_load = self._put('testing/models/dummy_semantic/load')
        self.assertEqual(resp_load.status_code, 200, resp_load.content.decode())
        resp_list_0 = self._get('models')
        self.assertEqual(resp_list_0.status_code, 200)
        rj0 = resp_list_0.json()
        self.assertEqual(len(rj0), 1, f'Unexpected models in response: {rj0}')
        # The restart response was previously ignored; a failed restart should
        # be reported as such, not as leftover models.
        resp_restart = self._get('session/restart')
        self.assertEqual(resp_restart.status_code, 200, resp_restart.content.decode())
        resp_list_1 = self._get('models')
        self.assertEqual(resp_list_1.status_code, 200)
        rj1 = resp_list_1.json()
        self.assertEqual(len(rj1), 0, f'Unexpected models in response: {rj1}')

    def test_change_inbound_path(self):
        """Pointing watch_output at the inbound directory makes both paths equal."""
        inbound = self._get('paths').json()['inbound_images']
        resp_change = self._put('paths/watch_output', query={'path': inbound})
        self.assertEqual(resp_change.status_code, 200)
        paths = self._get('paths').json()
        self.assertEqual(paths['inbound_images'], paths['outbound_images'])

    def test_exception_when_changing_inbound_path(self):
        """A nonexistent watch_output path gives 404 and leaves outbound_images unchanged."""
        outbound_before = self._get('paths').json()['outbound_images']
        fakepath = 'c:/fake/path/to/nowhere'
        resp_change = self._put('paths/watch_output', query={'path': fakepath})
        self.assertEqual(resp_change.status_code, 404)
        self.assertIn(fakepath, resp_change.json()['detail'])
        self.assertEqual(outbound_before, self._get('paths').json()['outbound_images'])

    def test_no_change_inbound_path(self):
        """Setting watch_output to its current value succeeds and changes nothing."""
        outbound_before = self._get('paths').json()['outbound_images']
        resp_change = self._put('paths/watch_output', query={'path': outbound_before})
        self.assertEqual(resp_change.status_code, 200)
        self.assertEqual(outbound_before, self._get('paths').json()['outbound_images'])

    def test_get_logs(self):
        """The first session log entry is 'Initialized session'."""
        resp = self._get('session/logs')
        self.assertEqual(resp.status_code, 200)
        self.assertEqual(resp.json()[0]['message'], 'Initialized session')

    def test_add_and_delete_accessor(self):
        """Accessors can be added from a file, listed, and deleted singly or all at once ('*')."""
        fname = self.copy_input_file_to_server()

        # add accessor to session
        acc_id = self._put(f'accessors/read_from_file/{fname}').json()
        self.assertTrue(acc_id.startswith('auto_'))

        # confirm that accessor is listed in session context
        listed = self._get('accessors').json()
        self.assertEqual(len(listed), 1)
        self.assertTrue(next(iter(listed)).startswith('auto_'))
        self.assertTrue(listed[acc_id]['loaded'])

        # delete and check that its 'loaded' state changes
        self.assertTrue(self._get(f'accessors/{acc_id}').json()['loaded'])
        self.assertEqual(self._get(f'accessors/delete/{acc_id}').json(), acc_id)
        self.assertFalse(self._get(f'accessors/{acc_id}').json()['loaded'])

        # and try a non-existent accessor ID
        resp_wrong_acc = self._get('accessors/auto_123456')
        self.assertEqual(resp_wrong_acc.status_code, 404)

        # load another... then remove all
        self._put(f'accessors/read_from_file/{fname}')
        self.assertEqual(sum(v['loaded'] for v in self._get('accessors').json().values()), 1)
        self.assertEqual(len(self._get('accessors/delete/*').json()), 1)
        self.assertEqual(sum(v['loaded'] for v in self._get('accessors').json().values()), 0)

    def test_empty_accessor_list(self):
        """The accessor listing is empty when none has been added."""
        self.assertEqual(len(self._get('accessors').json()), 0)

    def test_write_accessor(self):
        """A dummy accessor has an empty filepath, and get_accessor() reproduces its shape_dict."""
        # No leading slash, consistent with every other endpoint in this class.
        acc_id = self._put('testing/accessors/dummy_accessor/load').json()
        # Nothing changes between these reads, so fetch the accessor info once.
        acc_info = self._get(f'accessors/{acc_id}').json()
        self.assertTrue(acc_info['loaded'])
        self.assertEqual(acc_info['filepath'], '')
        acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif')
        self.assertEqual(acc_info['shape_dict'], acc_out.shape_dict)