Skip to content
Snippets Groups Projects
Commit 125847bc authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Reconfigured inference endpoint as pure PUT; can now pass params to model inference method

parent e2e7ea21
No related branches found
No related tags found
No related merge requests found
from pathlib import Path from typing import Dict
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
...@@ -18,21 +18,21 @@ def read_root(): ...@@ -18,21 +18,21 @@ def read_root():
@app.get('/models') @app.get('/models')
def list_active_models(): def list_active_models():
return session.describe_models() return session.describe_loaded_models()
@app.get('/models/{model_id}/load/') @app.put('/models/load/')
def load_model(model_id: str) -> dict: def load_model(model_id: str, params: Dict[str, str] = None) -> dict:
if model_id in session.models.keys(): if model_id in session.models.keys():
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail=f'Model with id {model_id} has already been loaded' detail=f'Model with id {model_id} has already been loaded'
) )
session.load_model(model_id) session.load_model(model_id, params=params)
return session.describe_models() return session.describe_loaded_models()
@app.put('/i2i/infer/{model_id}') # image file in, image file out @app.put('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.describe_models().keys(): if model_id not in session.describe_loaded_models().keys():
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail=f'Model {model_id} has not been loaded' detail=f'Model {model_id} has not been loaded'
...@@ -43,10 +43,10 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: ...@@ -43,10 +43,10 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
status_code=404, status_code=404,
detail=f'Could not find file:\n{inpath}' detail=f'Could not find file:\n{inpath}'
) )
model = session.models[model_id]['object']
record = infer_image_to_image( record = infer_image_to_image(
inpath, inpath,
session.models[model_id], session.models[model_id]['object'],
session.outbound.path, session.outbound.path,
channel=channel, channel=channel,
# TODO: optional callback for status reporting # TODO: optional callback for status reporting
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from math import floor from math import floor
import os
import numpy as np import numpy as np
...@@ -8,13 +9,15 @@ from model_server.image import GenericImageFileAccessor ...@@ -8,13 +9,15 @@ from model_server.image import GenericImageFileAccessor
class Model(ABC): class Model(ABC):
def __init__(self, autoload=True): def __init__(self, autoload=True, params=None):
""" """
Abstract base class for an inference model that uses image data as an input. Abstract base class for an inference model that uses image data as an input.
:param autoload: automatically load model and dependencies into memory if True :param autoload: automatically load model and dependencies into memory if True
:param params: Dict[str, str] of arguments e.g. configuration files required to load model
""" """
self.autoload = autoload self.autoload = autoload
self.params = params
if self.load(): if self.load():
self.loaded = True self.loaded = True
else: else:
...@@ -58,20 +61,20 @@ class Model(ABC): ...@@ -58,20 +61,20 @@ class Model(ABC):
class ImageToImageModel(Model): class ImageToImageModel(Model):
"""
def __init__(self, **kwargs): Abstract class for models that receive an image and return an image of the same size
""" """
Abstract class for models that receive an image and return an image of the same size
:param kwargs: variable length keyword arguments
"""
return super().__init__(**kwargs)
@abstractmethod @abstractmethod
def infer(self, img, channel=None) -> (np.ndarray, dict): def infer(self, img, channel=None) -> (np.ndarray, dict):
super().infer(img, channel) super().infer(img, channel)
class IlastikImageToImageModel(ImageToImageModel): class IlastikImageToImageModel(ImageToImageModel):
pass def load(self):
if 'project_file' not in self.params or not os.path.exists(self.params['project_file']):
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.project_file = self.params['project_file']
class DummyImageToImageModel(ImageToImageModel): class DummyImageToImageModel(ImageToImageModel):
...@@ -96,4 +99,7 @@ class ChannelTooHighError(Error): ...@@ -96,4 +99,7 @@ class ChannelTooHighError(Error):
pass pass
class CouldNotLoadModelError(Error): class CouldNotLoadModelError(Error):
pass
class ParameterExpectedError(Error):
pass pass
\ No newline at end of file
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
from pathlib import Path from pathlib import Path
from time import strftime, localtime from time import strftime, localtime
from typing import Dict
from conf.server import paths from conf.server import paths
from model_server.model import Model from model_server.model import Model
...@@ -61,27 +62,37 @@ class Session(object): ...@@ -61,27 +62,37 @@ class Session(object):
with open(self.manifest_json, 'w+') as fh: with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh) json.dump(record.dict(), fh)
def load_model(self, model_id: str) -> bool: def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool:
""" """
Load an instance of first model class that matches model_id string Load an instance of first model class that matches model_id string
:param model_id: :param model_id: string that uniquely defines a class of model
:param params: optional parameters that are passed upon loading a model
:return: True if model successfully loaded, False if not :return: True if model successfully loaded, False if not
""" """
models = Model.get_all_subclasses() models = Model.get_all_subclasses()
for mc in models: for mc in models:
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc() try:
assert mi.loaded mi = mc(params)
self.models[model_id] = mi assert mi.loaded
except:
raise CouldNotInstantiateModelError()
self.models[model_id] = {
'object': mi,
'params': params,
}
return True return True
raise CouldNotFindModelError( raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}', f'Could not find {model_id} in:\n{models}',
) )
return False return False
def describe_models(self) -> dict: def describe_loaded_models(self) -> dict:
return { return {
k: self.models[k].__class__.__name__ k: {
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
}
for k in self.models.keys() for k in self.models.keys()
} }
...@@ -95,4 +106,7 @@ class CouldNotFindModelError(Error): ...@@ -95,4 +106,7 @@ class CouldNotFindModelError(Error):
pass pass
class InferenceRecordError(Error): class InferenceRecordError(Error):
pass
class CouldNotInstantiateModelError(Error):
pass pass
\ No newline at end of file
...@@ -48,11 +48,16 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -48,11 +48,16 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_load_model(self): def test_load_model(self):
model_id = DummyImageToImageModel.model_id model_id = DummyImageToImageModel.model_id
resp_load = requests.get(self.uri + f'models/{model_id}/load') resp_load = requests.put(
self.uri + f'models/load',
params={'model_id': model_id}
)
self.assertEqual(resp_load.status_code, 200) self.assertEqual(resp_load.status_code, 200)
resp_list = requests.get(self.uri + 'models') resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200) self.assertEqual(resp_list.status_code, 200)
self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}') rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
def test_i2i_inference_errors_model_not_found(self): def test_i2i_inference_errors_model_not_found(self):
model_id = 'not_a_real_model' model_id = 'not_a_real_model'
...@@ -65,7 +70,10 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -65,7 +70,10 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_i2i_dummy_inference_by_api(self): def test_i2i_dummy_inference_by_api(self):
model = DummyImageToImageModel() model = DummyImageToImageModel()
resp_load = requests.get(self.uri + f'models/{model.model_id}/load') resp_load = requests.put(
self.uri + f'models/load',
params={'model_id': model.model_id}
)
self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}') self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
self.copy_input_file_to_server() self.copy_input_file_to_server()
resp_infer = requests.put( resp_infer = requests.put(
......
...@@ -42,9 +42,9 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -42,9 +42,9 @@ class TestGetSessionObject(unittest.TestCase):
model_id = DummyImageToImageModel.model_id model_id = DummyImageToImageModel.model_id
success = sesh.load_model(model_id) success = sesh.load_model(model_id)
self.assertTrue(success) self.assertTrue(success)
loaded_models = sesh.describe_models() loaded_models = sesh.describe_loaded_models()
self.assertTrue(model_id in loaded_models.keys()) self.assertTrue(model_id in loaded_models.keys())
self.assertEqual( self.assertEqual(
loaded_models[model_id], loaded_models[model_id]['class'],
DummyImageToImageModel.__name__ DummyImageToImageModel.__name__
) )
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment