Skip to content
Snippets Groups Projects
Commit 125847bc authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Reconfigured inference endpoint as pure PUT; can now pass params to model inference method

parent e2e7ea21
No related branches found
No related tags found
No related merge requests found
from pathlib import Path
from typing import Dict
from fastapi import FastAPI, HTTPException
......@@ -18,21 +18,21 @@ def read_root():
@app.get('/models')
def list_active_models():
return session.describe_models()
return session.describe_loaded_models()
@app.get('/models/{model_id}/load/')
def load_model(model_id: str) -> dict:
@app.put('/models/load/')
def load_model(model_id: str, params: Dict[str, str] = None) -> dict:
if model_id in session.models.keys():
raise HTTPException(
status_code=409,
detail=f'Model with id {model_id} has already been loaded'
)
session.load_model(model_id)
return session.describe_models()
session.load_model(model_id, params=params)
return session.describe_loaded_models()
@app.put('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.describe_models().keys():
if model_id not in session.describe_loaded_models().keys():
raise HTTPException(
status_code=409,
detail=f'Model {model_id} has not been loaded'
......@@ -43,10 +43,10 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
status_code=404,
detail=f'Could not find file:\n{inpath}'
)
model = session.models[model_id]['object']
record = infer_image_to_image(
inpath,
session.models[model_id],
session.models[model_id]['object'],
session.outbound.path,
channel=channel,
# TODO: optional callback for status reporting
......
from abc import ABC, abstractmethod
from math import floor
import os
import numpy as np
......@@ -8,13 +9,15 @@ from model_server.image import GenericImageFileAccessor
class Model(ABC):
def __init__(self, autoload=True):
def __init__(self, autoload=True, params=None):
"""
Abstract base class for an inference model that uses image data as an input.
:param autoload: automatically load model and dependencies into memory if True
:param params: Dict[str, str] of arguments e.g. configuration files required to load model
"""
self.autoload = autoload
self.params = params
if self.load():
self.loaded = True
else:
......@@ -58,20 +61,20 @@ class Model(ABC):
class ImageToImageModel(Model):
def __init__(self, **kwargs):
"""
Abstract class for models that receive an image and return an image of the same size
:param kwargs: variable length keyword arguments
"""
return super().__init__(**kwargs)
"""
Abstract class for models that receive an image and return an image of the same size
"""
@abstractmethod
def infer(self, img, channel=None) -> (np.ndarray, dict):
super().infer(img, channel)
class IlastikImageToImageModel(ImageToImageModel):
pass
def load(self):
if 'project_file' not in self.params or not os.path.exists(self.params['project_file']):
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.project_file = self.params['project_file']
class DummyImageToImageModel(ImageToImageModel):
......@@ -96,4 +99,7 @@ class ChannelTooHighError(Error):
pass
class CouldNotLoadModelError(Error):
pass
class ParameterExpectedError(Error):
pass
\ No newline at end of file
......@@ -3,6 +3,7 @@ import os
from pathlib import Path
from time import strftime, localtime
from typing import Dict
from conf.server import paths
from model_server.model import Model
......@@ -61,27 +62,37 @@ class Session(object):
with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh)
def load_model(self, model_id: str) -> bool:
def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool:
"""
Load an instance of first model class that matches model_id string
:param model_id:
:param model_id: string that uniquely defines a class of model
:param params: optional parameters that are passed upon loading a model
:return: True if model successfully loaded, False if not
"""
models = Model.get_all_subclasses()
for mc in models:
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc()
assert mi.loaded
self.models[model_id] = mi
try:
mi = mc(params)
assert mi.loaded
except:
raise CouldNotInstantiateModelError()
self.models[model_id] = {
'object': mi,
'params': params,
}
return True
raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}',
)
return False
def describe_models(self) -> dict:
def describe_loaded_models(self) -> dict:
return {
k: self.models[k].__class__.__name__
k: {
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
}
for k in self.models.keys()
}
......@@ -95,4 +106,7 @@ class CouldNotFindModelError(Error):
pass
class InferenceRecordError(Error):
pass
class CouldNotInstantiateModelError(Error):
pass
\ No newline at end of file
......@@ -48,11 +48,16 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_load_model(self):
model_id = DummyImageToImageModel.model_id
resp_load = requests.get(self.uri + f'models/{model_id}/load')
resp_load = requests.put(
self.uri + f'models/load',
params={'model_id': model_id}
)
self.assertEqual(resp_load.status_code, 200)
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}')
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
def test_i2i_inference_errors_model_not_found(self):
model_id = 'not_a_real_model'
......@@ -65,7 +70,10 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_i2i_dummy_inference_by_api(self):
model = DummyImageToImageModel()
resp_load = requests.get(self.uri + f'models/{model.model_id}/load')
resp_load = requests.put(
self.uri + f'models/load',
params={'model_id': model.model_id}
)
self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
self.copy_input_file_to_server()
resp_infer = requests.put(
......
......@@ -42,9 +42,9 @@ class TestGetSessionObject(unittest.TestCase):
model_id = DummyImageToImageModel.model_id
success = sesh.load_model(model_id)
self.assertTrue(success)
loaded_models = sesh.describe_models()
loaded_models = sesh.describe_loaded_models()
self.assertTrue(model_id in loaded_models.keys())
self.assertEqual(
loaded_models[model_id],
loaded_models[model_id]['class'],
DummyImageToImageModel.__name__
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment