Skip to content
Snippets Groups Projects
Commit e2e7ea21 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Successfully tested calling inference over API

parent 487b9b84
No related branches found
No related tags found
No related merge requests found
......@@ -30,19 +30,24 @@ def load_model(model_id: str) -> dict:
session.load_model(model_id)
return session.describe_models()
@app.post('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
if model_id not in session.models.keys():
@app.put('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.describe_models().keys():
raise HTTPException(
status_code=409,
detail=f'Model {model_id} has not been loaded'
)
inpath = session.inbound.path / input_filename
if not inpath.exists():
raise HTTPException(
status_code=404,
detail=f'Could not find file:\n{inpath}'
)
# TODO: try block workflow, catch and redirect HTTP
record = infer_image_to_image(
session.inbound / imgf,
inpath,
session.models[model_id],
session.outbound,
session.outbound.path,
channel=channel,
# TODO: optional callback for status reporting
)
......
from pathlib import Path
root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources')
filename = 'Selection--W0000--P0001-T0001.czi'
czifile = {
'path': root / 'testdata' / 'Selection--W0000--P0001-T0001.czi',
'filename': filename,
'path': root / 'testdata' / filename,
'w': 1024,
'h': 1024,
'c': 4,
......
......@@ -54,17 +54,17 @@ def generate_file_accessor(fpath):
else:
raise FileAccessorError(f'Could not match a file accessor with {fpath}')
def Error(Exception):
class Error(Exception):
pass
def FileAccessorError(Error):
class FileAccessorError(Error):
pass
def FileNotFoundError(Error):
class FileNotFoundError(Error):
pass
def FileShapeError(Error):
class FileShapeError(Error):
pass
def FileWriteError(Error):
class FileWriteError(Error):
pass
\ No newline at end of file
......@@ -72,7 +72,7 @@ class Session(object):
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc()
assert mi.loaded
self.models['model_id'] = mi
self.models[model_id] = mi
return True
raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}',
......
import os
import pathlib
def validate_directory_exists(path):
return os.path.exists(path)
class SharedImageDirectory(object):
def __init__(self, path):
def __init__(self, path: pathlib.Path):
self.path = path
self.is_memory_mapped = False
......
from multiprocessing import Process
import requests
import unittest
from conf.testing import czifile, output_path
from conf.testing import czifile
from model_server.model import DummyImageToImageModel
class TestApiFromAutomatedClient(unittest.TestCase):
......@@ -20,6 +21,19 @@ class TestApiFromAutomatedClient(unittest.TestCase):
self.uri = f'http://{host}:{port}/'
self.server_process.start()
@staticmethod
def copy_input_file_to_server():
import pathlib
from shutil import copyfile
from conf.server import paths
outpath = pathlib.Path(paths['images']['inbound'] / czifile['filename'])
copyfile(
czifile['path'],
outpath
)
def tearDown(self) -> None:
self.server_process.terminate()
......@@ -38,19 +52,24 @@ class TestApiFromAutomatedClient(unittest.TestCase):
self.assertEqual(resp_load.status_code, 200)
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}')
self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}')
def test_i2i_inference_errors_model_not_sound(self):
def test_i2i_inference_errors_model_not_found(self):
model_id = 'not_a_real_model'
resp = requests.post(self.uri + f'i2i/infer/{model_id}')
self.assertEqual(resp.status_code, 404)
resp = requests.put(
self.uri + f'i2i/infer/{model_id}',
params={'input_filename': 'not_a_real_file.name'}
)
print(resp.content)
self.assertEqual(resp.status_code, 409)
def test_i2i_dummy_inference_by_api(self):
model = DummyImageToImageModel()
model_id = model.model_id
resp = requests.post(
self.uri + f'/i2i/infer/{model_id}',
str(czifile['path']),
resp_load = requests.get(self.uri + f'models/{model.model_id}/load')
self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
self.copy_input_file_to_server()
resp_infer = requests.put(
self.uri + f'i2i/infer/{model.model_id}',
params={'input_filename': czifile['filename']},
)
print(resp)
self.assertEqual(resp.status_code, 200)
\ No newline at end of file
self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}')
\ No newline at end of file
......@@ -37,9 +37,14 @@ class TestGetSessionObject(unittest.TestCase):
do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_load_model(self):
def test_session_loads_model(self):
sesh = Session()
success = sesh.load_model(DummyImageToImageModel.model_id)
model_id = DummyImageToImageModel.model_id
success = sesh.load_model(model_id)
self.assertTrue(success)
self.assertTrue('model_id' in sesh.models.keys())
self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
loaded_models = sesh.describe_models()
self.assertTrue(model_id in loaded_models.keys())
self.assertEqual(
loaded_models[model_id],
DummyImageToImageModel.__name__
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment