Skip to content
Snippets Groups Projects
Commit e2e7ea21 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Successfully tested calling inference over API

parent 487b9b84
No related branches found
No related tags found
No related merge requests found
...@@ -30,19 +30,24 @@ def load_model(model_id: str) -> dict: ...@@ -30,19 +30,24 @@ def load_model(model_id: str) -> dict:
session.load_model(model_id) session.load_model(model_id)
return session.describe_models() return session.describe_models()
@app.post('/i2i/infer/{model_id}') # image file in, image file out @app.put('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.models.keys(): if model_id not in session.describe_models().keys():
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail=f'Model {model_id} has not been loaded' detail=f'Model {model_id} has not been loaded'
) )
inpath = session.inbound.path / input_filename
if not inpath.exists():
raise HTTPException(
status_code=404,
detail=f'Could not find file:\n{inpath}'
)
# TODO: try block workflow, catch and redirect HTTP
record = infer_image_to_image( record = infer_image_to_image(
session.inbound / imgf, inpath,
session.models[model_id], session.models[model_id],
session.outbound, session.outbound.path,
channel=channel, channel=channel,
# TODO: optional callback for status reporting # TODO: optional callback for status reporting
) )
......
from pathlib import Path from pathlib import Path
root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources') root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources')
filename = 'Selection--W0000--P0001-T0001.czi'
czifile = { czifile = {
'path': root / 'testdata' / 'Selection--W0000--P0001-T0001.czi', 'filename': filename,
'path': root / 'testdata' / filename,
'w': 1024, 'w': 1024,
'h': 1024, 'h': 1024,
'c': 4, 'c': 4,
......
...@@ -54,17 +54,17 @@ def generate_file_accessor(fpath): ...@@ -54,17 +54,17 @@ def generate_file_accessor(fpath):
else: else:
raise FileAccessorError(f'Could not match a file accessor with {fpath}') raise FileAccessorError(f'Could not match a file accessor with {fpath}')
def Error(Exception): class Error(Exception):
pass pass
def FileAccessorError(Error): class FileAccessorError(Error):
pass pass
def FileNotFoundError(Error): class FileNotFoundError(Error):
pass pass
def FileShapeError(Error): class FileShapeError(Error):
pass pass
def FileWriteError(Error): class FileWriteError(Error):
pass pass
\ No newline at end of file
...@@ -72,7 +72,7 @@ class Session(object): ...@@ -72,7 +72,7 @@ class Session(object):
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc() mi = mc()
assert mi.loaded assert mi.loaded
self.models['model_id'] = mi self.models[model_id] = mi
return True return True
raise CouldNotFindModelError( raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}', f'Could not find {model_id} in:\n{models}',
......
import os import os
import pathlib
def validate_directory_exists(path): def validate_directory_exists(path):
return os.path.exists(path) return os.path.exists(path)
class SharedImageDirectory(object): class SharedImageDirectory(object):
def __init__(self, path): def __init__(self, path: pathlib.Path):
self.path = path self.path = path
self.is_memory_mapped = False self.is_memory_mapped = False
......
from multiprocessing import Process from multiprocessing import Process
import requests import requests
import unittest import unittest
from conf.testing import czifile, output_path from conf.testing import czifile
from model_server.model import DummyImageToImageModel from model_server.model import DummyImageToImageModel
class TestApiFromAutomatedClient(unittest.TestCase): class TestApiFromAutomatedClient(unittest.TestCase):
...@@ -20,6 +21,19 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -20,6 +21,19 @@ class TestApiFromAutomatedClient(unittest.TestCase):
self.uri = f'http://{host}:{port}/' self.uri = f'http://{host}:{port}/'
self.server_process.start() self.server_process.start()
@staticmethod
def copy_input_file_to_server():
import pathlib
from shutil import copyfile
from conf.server import paths
outpath = pathlib.Path(paths['images']['inbound'] / czifile['filename'])
copyfile(
czifile['path'],
outpath
)
def tearDown(self) -> None: def tearDown(self) -> None:
self.server_process.terminate() self.server_process.terminate()
...@@ -38,19 +52,24 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -38,19 +52,24 @@ class TestApiFromAutomatedClient(unittest.TestCase):
self.assertEqual(resp_load.status_code, 200) self.assertEqual(resp_load.status_code, 200)
resp_list = requests.get(self.uri + 'models') resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200) self.assertEqual(resp_list.status_code, 200)
self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}') self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}')
def test_i2i_inference_errors_model_not_sound(self): def test_i2i_inference_errors_model_not_found(self):
model_id = 'not_a_real_model' model_id = 'not_a_real_model'
resp = requests.post(self.uri + f'i2i/infer/{model_id}') resp = requests.put(
self.assertEqual(resp.status_code, 404) self.uri + f'i2i/infer/{model_id}',
params={'input_filename': 'not_a_real_file.name'}
)
print(resp.content)
self.assertEqual(resp.status_code, 409)
def test_i2i_dummy_inference_by_api(self): def test_i2i_dummy_inference_by_api(self):
model = DummyImageToImageModel() model = DummyImageToImageModel()
model_id = model.model_id resp_load = requests.get(self.uri + f'models/{model.model_id}/load')
resp = requests.post( self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
self.uri + f'/i2i/infer/{model_id}', self.copy_input_file_to_server()
str(czifile['path']), resp_infer = requests.put(
self.uri + f'i2i/infer/{model.model_id}',
params={'input_filename': czifile['filename']},
) )
print(resp) self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}')
self.assertEqual(resp.status_code, 200) \ No newline at end of file
\ No newline at end of file
...@@ -37,9 +37,14 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -37,9 +37,14 @@ class TestGetSessionObject(unittest.TestCase):
do = json.load(fh) do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct') self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_load_model(self): def test_session_loads_model(self):
sesh = Session() sesh = Session()
success = sesh.load_model(DummyImageToImageModel.model_id) model_id = DummyImageToImageModel.model_id
success = sesh.load_model(model_id)
self.assertTrue(success) self.assertTrue(success)
self.assertTrue('model_id' in sesh.models.keys()) loaded_models = sesh.describe_models()
self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel) self.assertTrue(model_id in loaded_models.keys())
\ No newline at end of file self.assertEqual(
loaded_models[model_id],
DummyImageToImageModel.__name__
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment