Skip to content
Snippets Groups Projects
Commit fe4dda4f authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Option to skip loading a model if its project file is already in a loaded one

parent 3cc6ffd0
No related branches found
No related tags found
No related merge requests found
from fastapi import FastAPI, HTTPException
from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.ilastik import IlastikImageToImageModel, IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.model import DummyImageToImageModel, ParameterExpectedError
from model_server.session import Session
from model_server.workflow import infer_image_to_image
......@@ -38,7 +38,18 @@ def list_active_models():
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummyImageToImageModel)}
def load_ilastik_model(model_class, project_file):
def load_ilastik_model(model_class: IlastikImageToImageModel, project_file: str, duplicate=True) -> dict:
"""
Load an ilastik model of a given class and project filename.
:param model_class:
:param project_file: (*.ilp) ilastik project filename
:param duplicate: load another instance of the same project file if True; return existing one if false
:return: dictionary with single key describing model's ID
"""
if not duplicate:
existing_model = session.find_param_in_loaded_models('project_file', project_file)
if existing_model is not None:
return existing_model
try:
result = {
'model_id': session.load_model(
......@@ -54,12 +65,12 @@ def load_ilastik_model(model_class, project_file):
return result
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(project_file: str) -> dict:
return load_ilastik_model(IlastikPixelClassifierModel, project_file)
def load_ilastik_pixel_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikPixelClassifierModel, project_file, duplicate=duplicate)
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(project_file: str) -> dict:
return load_ilastik_model(IlastikObjectClassifierModel, project_file)
def load_ilastik_object_classification_model(project_file: str, duplicate: bool = True) -> dict:
return load_ilastik_model(IlastikObjectClassifierModel, project_file, duplicate=duplicate)
def validate_workflow_inputs(model_ids, inpaths):
for mid in model_ids:
......
......@@ -125,6 +125,16 @@ class Session(object):
for k in self.models.keys()
}
def find_param_in_loaded_models(self, key: str, value: str) -> dict:
"""
Returns first instance of loaded model where key and value match with .params field, or None
"""
models = self.describe_loaded_models()
for mid, det in models.items():
if det.get('params').get(key) == value:
return {mid: det}
return None
def restart(self, **kwargs):
self.__init__(**kwargs)
......
......@@ -132,6 +132,29 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(rj[model_id]['class'], 'IlastikPixelClassifierModel')
return model_id
def test_load_another_ilastik_pixel_model(self):
model_id = self.test_load_ilastik_pixel_model()
resp_list_1st = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
resp_load_2nd = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'duplicate': True,
},
)
resp_list_2nd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
resp_load_3rd = requests.put(
self.uri + 'models/ilastik/pixel_classification/load/',
params={
'project_file': str(conf.testing.ilastik['pixel_classifier']),
'duplicate': False,
},
)
resp_list_3rd = requests.get(self.uri + 'models').json()
self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)
def test_load_ilastik_object_model(self):
resp_load = requests.put(
......
......@@ -79,8 +79,17 @@ class TestGetSessionObject(unittest.TestCase):
def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummyImageToImageModel
p = {'p1': 'abc'}
success = sesh.load_model(MC, params=p)
p1 = {'p1': 'abc'}
success = sesh.load_model(MC, params=p1)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
mid = MC.__name__ + '_00'
self.assertEqual(loaded_models[mid]['params'], p1)
# load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'}
sesh.load_model(MC, params=p2)
find_kv = sesh.find_param_in_loaded_models('p1', 'abc')
self.assertEqual(len(find_kv), 1)
self.assertEqual(find_kv[mid]['params'], p1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment