From e2e7ea215281e9dfd9f245671f715f60daac2037 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 28 Aug 2023 14:23:04 +0200 Subject: [PATCH] Successfully tested calling inference over API --- api.py | 17 +++++++++++------ conf/testing.py | 5 +++-- model_server/image.py | 10 +++++----- model_server/session.py | 2 +- model_server/share.py | 3 ++- tests/test_api.py | 41 ++++++++++++++++++++++++++++++----------- tests/test_session.py | 13 +++++++++---- 7 files changed, 61 insertions(+), 30 deletions(-) diff --git a/api.py b/api.py index 5a5f25aa..f5d7976f 100644 --- a/api.py +++ b/api.py @@ -30,19 +30,24 @@ def load_model(model_id: str) -> dict: session.load_model(model_id) return session.describe_models() -@app.post('/i2i/infer/{model_id}') # image file in, image file out -def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: - if model_id not in session.models.keys(): +@app.put('/i2i/infer/{model_id}') # image file in, image file out +def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: + if model_id not in session.describe_models().keys(): raise HTTPException( status_code=409, detail=f'Model {model_id} has not been loaded' ) + inpath = session.inbound.path / input_filename + if not inpath.exists(): + raise HTTPException( + status_code=404, + detail=f'Could not find file:\n{inpath}' + ) - # TODO: try block workflow, catch and redirect HTTP record = infer_image_to_image( - session.inbound / imgf, + inpath, session.models[model_id], - session.outbound, + session.outbound.path, channel=channel, # TODO: optional callback for status reporting ) diff --git a/conf/testing.py b/conf/testing.py index 753153a3..94a42aa8 100644 --- a/conf/testing.py +++ b/conf/testing.py @@ -1,9 +1,10 @@ from pathlib import Path root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources') - +filename = 'Selection--W0000--P0001-T0001.czi' czifile = { - 'path': root / 'testdata' / 'Selection--W0000--P0001-T0001.czi', + 'filename': filename, + 'path': root / 'testdata' / filename, 'w': 1024, 'h': 1024, 'c': 4, diff --git a/model_server/image.py b/model_server/image.py index ec9dd677..bd6b1abc 100644 --- a/model_server/image.py +++ b/model_server/image.py @@ -54,17 +54,17 @@ def generate_file_accessor(fpath): else: raise FileAccessorError(f'Could not match a file accessor with {fpath}') -def Error(Exception): +class Error(Exception): pass -def FileAccessorError(Error): +class FileAccessorError(Error): pass -def FileNotFoundError(Error): +class FileNotFoundError(Error): pass -def FileShapeError(Error): +class FileShapeError(Error): pass -def FileWriteError(Error): +class FileWriteError(Error): pass \ No newline at end of file diff --git a/model_server/session.py b/model_server/session.py index 66a3ac11..f3b65d62 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -72,7 +72,7 @@ class Session(object): if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: mi = mc() assert mi.loaded - self.models['model_id'] = mi + self.models[model_id] = mi return True raise CouldNotFindModelError( f'Could not find {model_id} in:\n{models}', diff --git a/model_server/share.py b/model_server/share.py index e817febc..e270dea6 100644 --- a/model_server/share.py +++ b/model_server/share.py @@ -1,11 +1,12 @@ import os +import pathlib def validate_directory_exists(path): return os.path.exists(path) class SharedImageDirectory(object): - def __init__(self, path): + def __init__(self, path: pathlib.Path): self.path = path self.is_memory_mapped = False diff --git a/tests/test_api.py b/tests/test_api.py index 17afe2e3..ecc0296c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,8 +1,9 @@ from multiprocessing import Process + import requests import unittest -from conf.testing import czifile, output_path +from conf.testing import czifile from model_server.model import DummyImageToImageModel class TestApiFromAutomatedClient(unittest.TestCase): @@ -20,6 +21,19 @@ class TestApiFromAutomatedClient(unittest.TestCase): self.uri = f'http://{host}:{port}/' self.server_process.start() + @staticmethod + def copy_input_file_to_server(): + import pathlib + from shutil import copyfile + from conf.server import paths + + outpath = pathlib.Path(paths['images']['inbound'] / czifile['filename']) + + copyfile( + czifile['path'], + outpath + ) + def tearDown(self) -> None: self.server_process.terminate() @@ -38,19 +52,24 @@ class TestApiFromAutomatedClient(unittest.TestCase): self.assertEqual(resp_load.status_code, 200) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) - self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}') + self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}') - def test_i2i_inference_errors_model_not_sound(self): + def test_i2i_inference_errors_model_not_found(self): model_id = 'not_a_real_model' - resp = requests.post(self.uri + f'i2i/infer/{model_id}') - self.assertEqual(resp.status_code, 404) + resp = requests.put( + self.uri + f'i2i/infer/{model_id}', + params={'input_filename': 'not_a_real_file.name'} + ) + print(resp.content) + self.assertEqual(resp.status_code, 409) def test_i2i_dummy_inference_by_api(self): model = DummyImageToImageModel() - model_id = model.model_id - resp = requests.post( - self.uri + f'/i2i/infer/{model_id}', - str(czifile['path']), + resp_load = requests.get(self.uri + f'models/{model.model_id}/load') + self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}') + self.copy_input_file_to_server() + resp_infer = requests.put( + self.uri + f'i2i/infer/{model.model_id}', + params={'input_filename': czifile['filename']}, ) - print(resp) - self.assertEqual(resp.status_code, 200) \ No newline at end of file + self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}') \ No newline at end of file diff --git a/tests/test_session.py b/tests/test_session.py index 5d10b07c..2538d3d2 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -37,9 +37,14 @@ class TestGetSessionObject(unittest.TestCase): do = json.load(fh) self.assertEqual(di.dict(), do, 'Manifest record is not correct') - def test_session_load_model(self): + def test_session_loads_model(self): sesh = Session() - success = sesh.load_model(DummyImageToImageModel.model_id) + model_id = DummyImageToImageModel.model_id + success = sesh.load_model(model_id) self.assertTrue(success) - self.assertTrue('model_id' in sesh.models.keys()) - self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) \ No newline at end of file + loaded_models = sesh.describe_models() + self.assertTrue(model_id in loaded_models.keys()) + self.assertEqual( + loaded_models[model_id], + DummyImageToImageModel.__name__ + ) \ No newline at end of file -- GitLab