Skip to content
Snippets Groups Projects
Commit 3a7f41e0 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Updated API for loading models but still need to makes some fixes for tests to pass

parent 9010a5a1
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,9 @@ from typing import Dict
from fastapi import FastAPI, HTTPException
from model_server.session import CouldNotFindModelError, Session
from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.model import DummyImageToImageModel
from model_server.session import Session
from model_server.workflow import infer_image_to_image
app = FastAPI(debug=True)
......@@ -20,22 +22,17 @@ def read_root():
def list_active_models():
return session.describe_loaded_models()
@app.put('/models/load/')
# def load_model(model_id: str, misc: Dict[str, str]) -> dict:
def load_model(model_id: str, misc: dict) -> dict:
if model_id in session.models.keys():
raise HTTPException(
status_code=409,
detail=f'Model with id {model_id} has already been loaded'
)
try:
session.load_model(model_id, params=misc)
except CouldNotFindModelError:
raise HTTPException(
status_code=404,
detail=f'Could not find {model_id} in defined models'
)
return session.describe_loaded_models()
@app.put('/models/dummy/load/')
def load_dummy_model(params: str = None) -> dict:
return session.load_model(DummyImageToImageModel, params)
@app.put('/models/ilastik/pixel_classification/load/')
def load_ilastik_pixel_classification_model(params: str) -> dict:
return session.load_model(IlastikPixelClassifierModel, params)
@app.put('/models/ilastik/object_classification/load/')
def load_ilastik_object_classification_model(params: str) -> dict:
return session.load_model(IlastikObjectClassifierModel, params)
@app.put('/i2i/infer/')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
import model_server.ilastik
import model_server.model
def get_all_model_subclasses():
"""
Recursively find all subclasses of Model
:return: set of all subclasses of Model
"""
def get_all_subclasses_of(cc):
return set(cc.__subclasses__()).union(
[s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
return get_all_subclasses_of(model_server.model.Model)
if __name__ == '__main__':
print(get_all_model_subclasses())
\ No newline at end of file
......@@ -7,7 +7,6 @@ from typing import Dict
from conf.server import paths
from model_server.model import Model
from model_server.model_registry import get_all_model_subclasses
from model_server.share import SharedImageDirectory
from model_server.workflow import WorkflowRunRecord
......@@ -64,40 +63,30 @@ class Session(object):
with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh)
def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool:
def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict:
"""
Load an instance of first model class that matches model_id string
:param model_id: string that uniquely defines a class of model
:param params: optional parameters that are passed upon loading a model
:return: True if model successfully loaded, False if not
Load an instance of a given model class and attach to this session's model registry
:param ModelClass: subclass of Model
:param params: optional parameters that are passed to the model's construct
:return: dictionary that describes all currently loaded models
"""
models = get_all_model_subclasses()
for mc in models:
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
try:
mi = mc(params=params)
assert mi.loaded
except:
raise CouldNotInstantiateModelError()
self.models[model_id] = {
'object': mi,
'params': params,
}
self.log_event(f'Loaded model {model_id}')
return True
raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}',
)
return False
mi = ModelClass(params=params)
assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
ii = 0
def mid(i): return f'{ModelClass.__name__}_{ii:02d}'
while mid(ii) in self.models.keys():
ii += 1
self.models[mid(ii)] = {
'object': mi,
'params': params
}
self.log_event(f'Loaded model {mid}')
return self.describe_loaded_models()
def describe_loaded_models(self) -> dict:
return {
k: {
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
}
for k in self.models.keys()
}
# TODO: explictly make this JSON-compatible
return self.models
def restart(self):
self.__init__()
......@@ -105,9 +94,6 @@ class Session(object):
class Error(Exception):
pass
class CouldNotFindModelError(Error):
pass
class InferenceRecordError(Error):
pass
......
from multiprocessing import Process
import requests
import unittest
......@@ -50,10 +49,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_load_dummy_model(self):
model_id = DummyImageToImageModel.model_id
resp_load = requests.put(
self.uri + f'models/load',
params={'model_id': model_id}
self.uri + f'models/dummy/load',
# params={'misc': {'d': 'e'}}
)
self.assertEqual(resp_load.status_code, 200)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json()
......
......@@ -37,14 +37,36 @@ class TestGetSessionObject(unittest.TestCase):
do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_loads_model(self):
sesh = Session()
model_id = DummyImageToImageModel.model_id
success = sesh.load_model(model_id)
MC = DummyImageToImageModel
success = sesh.load_model(MC)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
self.assertTrue(model_id in loaded_models.keys())
self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys()
)
self.assertEqual(
loaded_models[model_id]['class'],
DummyImageToImageModel.__name__
)
\ No newline at end of file
loaded_models[MC.__name__ + '_00']['object'].__class__,
MC
)
def test_session_loads_second_instance_of_same_model(self):
sesh = Session()
MC = DummyImageToImageModel
sesh.load_model(MC)
sesh.load_model(MC)
print(sesh.models.keys())
self.assertIn(MC.__name__ + '_00', sesh.models.keys())
self.assertIn(MC.__name__ + '_01', sesh.models.keys())
def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummyImageToImageModel
p = {'p1': 'abc'}
success = sesh.load_model(MC, params=p)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment