Skip to content
Snippets Groups Projects
Commit d331f923 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Changed superclass of dummy model

parent a29dc6bb
No related branches found
No related tags found
No related merge requests found
......@@ -18,16 +18,17 @@ def read_root():
@app.get('/models')
def list_active_models():
return session.models # TODO: include model type too
return session.describe_models()
@app.get('/models/{model_id}/load/')
def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON?
def load_model(model_id: str) -> dict:
if model_id in session.models.keys():
raise HTTPException(
status_code=409,
detail=f'Model with id {model_id} has already been loaded'
)
session
session.load_model(model_id)
return session.describe_models()
@app.post('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
......
......@@ -20,6 +20,7 @@ class Model(ABC):
else:
self.loaded = False
raise CouldNotLoadModelError()
return None
@abstractmethod
def load(self):
......@@ -44,6 +45,7 @@ class Model(ABC):
def reload(self):
self.load()
class ImageToImageModel(Model):
def __init__(self, **kwargs):
......@@ -51,7 +53,7 @@ class ImageToImageModel(Model):
Abstract class for models that receive an image and return an image of the same size
:param kwargs: variable length keyword arguments
"""
return super(**kwargs)
return super().__init__(**kwargs)
@abstractmethod
def infer(self, img, channel=None) -> (np.ndarray, dict):
......@@ -60,7 +62,7 @@ class ImageToImageModel(Model):
class IlastikImageToImageModel(ImageToImageModel):
pass
class DummyImageToImageModel(Model):
class DummyImageToImageModel(ImageToImageModel):
model_id = 'dummy_make_white_square'
......
......@@ -71,10 +71,15 @@ class Session(object):
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc()
assert mi.loaded
self.models.append(mi)
self.models['model_id'] = mi
return True
return False
def describe_models(self) -> dict:
return {
k: self.models[k].__class__.__name__
for k in self.models.keys()
}
def restart(self):
self.__init__()
......
......@@ -29,11 +29,15 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_list_empty_loaded_models(self):
resp = requests.get(self.uri + 'models')
print(resp.content)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.content, b'{}')
def test_load_model(self):
resp = requests.get(self.uri + 'load_')
model_id = DummyImageToImageModel.model_id
resp = requests.get(self.uri + f'models/{model_id}/load')
self.assertEqual(resp.status_code, 200)
loaded = requests.get(self.uri + 'models')
self.assertEqual(loaded.content, b'{"model_id":"DummyImageToImageModel"}')
def test_i2i_inference_errors_model_not_sound(self):
model_id = 'not_a_real_model'
......
import unittest
from model_server.model import DummyImageToImageModel
from model_server.session import Session
class TestGetSessionObject(unittest.TestCase):
......@@ -32,3 +33,8 @@ class TestGetSessionObject(unittest.TestCase):
do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_load_model(self):
sesh = Session()
self.assertTrue(sesh.load_model(DummyImageToImageModel.model_id))
self.assertTrue('model_id' in sesh.models.keys())
self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment