Skip to content
Snippets Groups Projects
Commit e147a54f authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Session returns JSON-compliant model list

parent 3a7f41e0
No related branches found
No related tags found
No related merge requests found
......@@ -23,8 +23,8 @@ def list_active_models():
return session.describe_loaded_models()
@app.put('/models/dummy/load/')
def load_dummy_model(params: str = None) -> dict:
return session.load_model(DummyImageToImageModel, params)
def load_dummy_model() -> dict:
return session.load_model(DummyImageToImageModel)
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(params: str) -> dict:
......
......@@ -85,8 +85,13 @@ class Session(object):
return self.describe_loaded_models()
def describe_loaded_models(self) -> dict:
# TODO: explictly make this JSON-compatible
return self.models
return {
k: {
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
}
for k in self.models.keys()
}
def restart(self):
self.__init__()
......
......@@ -47,17 +47,15 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
self.assertEqual(resp.content, b'{}')
def test_load_dummy_model(self):
model_id = DummyImageToImageModel.model_id
model_key = DummyImageToImageModel.__name__ + '_00'
resp_load = requests.put(
self.uri + f'models/dummy/load',
# params={'misc': {'d': 'e'}}
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
return model_id
self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel')
def test_respond_with_error_when_invalid_model_loaded(self):
model_id = 'not_a_real_model'
......@@ -69,7 +67,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
print(resp.content)
def test_respond_with_error_when_invalid_filepath_requested(self):
model_id = self.test_load_dummy_model()
# model_id = self.test_load_dummy_model()
resp = requests.put(
self.uri + f'i2i/infer/',
params={
......
......@@ -44,12 +44,13 @@ class TestGetSessionObject(unittest.TestCase):
success = sesh.load_model(MC)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
print(loaded_models)
self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys()
)
self.assertEqual(
loaded_models[MC.__name__ + '_00']['object'].__class__,
MC
loaded_models[MC.__name__ + '_00']['class'],
MC.__name__
)
def test_session_loads_second_instance_of_same_model(self):
......@@ -69,4 +70,4 @@ class TestGetSessionObject(unittest.TestCase):
success = sesh.load_model(MC, params=p)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
\ No newline at end of file
self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment