Skip to content
Snippets Groups Projects
Commit d331f923 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Changed superclass of dummy model

parent a29dc6bb
No related branches found
No related tags found
No related merge requests found
...@@ -18,16 +18,17 @@ def read_root(): ...@@ -18,16 +18,17 @@ def read_root():
@app.get('/models') @app.get('/models')
def list_active_models(): def list_active_models():
return session.models # TODO: include model type too return session.describe_models()
@app.get('/models/{model_id}/load/') @app.get('/models/{model_id}/load/')
def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON? def load_model(model_id: str) -> dict:
if model_id in session.models.keys(): if model_id in session.models.keys():
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail=f'Model with id {model_id} has already been loaded' detail=f'Model with id {model_id} has already been loaded'
) )
session session.load_model(model_id)
return session.describe_models()
@app.post('/i2i/infer/{model_id}') # image file in, image file out @app.post('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
......
...@@ -20,6 +20,7 @@ class Model(ABC): ...@@ -20,6 +20,7 @@ class Model(ABC):
else: else:
self.loaded = False self.loaded = False
raise CouldNotLoadModelError() raise CouldNotLoadModelError()
return None
@abstractmethod @abstractmethod
def load(self): def load(self):
...@@ -44,6 +45,7 @@ class Model(ABC): ...@@ -44,6 +45,7 @@ class Model(ABC):
def reload(self): def reload(self):
self.load() self.load()
class ImageToImageModel(Model): class ImageToImageModel(Model):
def __init__(self, **kwargs): def __init__(self, **kwargs):
...@@ -51,7 +53,7 @@ class ImageToImageModel(Model): ...@@ -51,7 +53,7 @@ class ImageToImageModel(Model):
Abstract class for models that receive an image and return an image of the same size Abstract class for models that receive an image and return an image of the same size
:param kwargs: variable length keyword arguments :param kwargs: variable length keyword arguments
""" """
return super(**kwargs) return super().__init__(**kwargs)
@abstractmethod @abstractmethod
def infer(self, img, channel=None) -> (np.ndarray, dict): def infer(self, img, channel=None) -> (np.ndarray, dict):
...@@ -60,7 +62,7 @@ class ImageToImageModel(Model): ...@@ -60,7 +62,7 @@ class ImageToImageModel(Model):
class IlastikImageToImageModel(ImageToImageModel): class IlastikImageToImageModel(ImageToImageModel):
pass pass
class DummyImageToImageModel(Model): class DummyImageToImageModel(ImageToImageModel):
model_id = 'dummy_make_white_square' model_id = 'dummy_make_white_square'
......
...@@ -71,10 +71,15 @@ class Session(object): ...@@ -71,10 +71,15 @@ class Session(object):
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc() mi = mc()
assert mi.loaded assert mi.loaded
self.models.append(mi) self.models['model_id'] = mi
return True return True
return False return False
def describe_models(self) -> dict:
return {
k: self.models[k].__class__.__name__
for k in self.models.keys()
}
def restart(self): def restart(self):
self.__init__() self.__init__()
......
...@@ -29,11 +29,15 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -29,11 +29,15 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_list_empty_loaded_models(self): def test_list_empty_loaded_models(self):
resp = requests.get(self.uri + 'models') resp = requests.get(self.uri + 'models')
print(resp.content)
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.content, b'{}')
def test_load_model(self): def test_load_model(self):
resp = requests.get(self.uri + 'load_') model_id = DummyImageToImageModel.model_id
resp = requests.get(self.uri + f'models/{model_id}/load')
self.assertEqual(resp.status_code, 200)
loaded = requests.get(self.uri + 'models')
self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}')
def test_i2i_inference_errors_model_not_sound(self): def test_i2i_inference_errors_model_not_sound(self):
model_id = 'not_a_real_model' model_id = 'not_a_real_model'
......
import unittest import unittest
from model_server.model import DummyImageToImageModel
from model_server.session import Session from model_server.session import Session
class TestGetSessionObject(unittest.TestCase): class TestGetSessionObject(unittest.TestCase):
...@@ -32,3 +33,8 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -32,3 +33,8 @@ class TestGetSessionObject(unittest.TestCase):
do = json.load(fh) do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct') self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_load_model(self):
sesh = Session()
self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id))
self.assertTrue('model_id' in sesh.models.keys())
self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment