Skip to content
Snippets Groups Projects
Commit 3a7f41e0 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Updated API for loading models but still need to makes some fixes for tests to pass

parent 9010a5a1
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,9 @@ from typing import Dict ...@@ -2,7 +2,9 @@ from typing import Dict
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from model_server.session import CouldNotFindModelError, Session from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel
from model_server.model import DummyImageToImageModel
from model_server.session import Session
from model_server.workflow import infer_image_to_image from model_server.workflow import infer_image_to_image
app = FastAPI(debug=True) app = FastAPI(debug=True)
...@@ -20,22 +22,17 @@ def read_root(): ...@@ -20,22 +22,17 @@ def read_root():
def list_active_models(): def list_active_models():
return session.describe_loaded_models() return session.describe_loaded_models()
@app.put('/models/load/') @app.put('/models/dummy/load/')
# def load_model(model_id: str, misc: Dict[str, str]) -> dict: def load_dummy_model(params: str = None) -> dict:
def load_model(model_id: str, misc: dict) -> dict: return session.load_model(DummyImageToImageModel, params)
if model_id in session.models.keys():
raise HTTPException( @app.put('/models/ilastik/pixel_classification/load/')
status_code=409, def load_ilastik_pixel_classification_model(params: str) -> dict:
detail=f'Model with id {model_id} has already been loaded' return session.load_model(IlastikPixelClassifierModel, params)
)
try: @app.put('/models/ilastik/object_classification/load/')
session.load_model(model_id, params=misc) def load_ilastik_object_classification_model(params: str) -> dict:
except CouldNotFindModelError: return session.load_model(IlastikObjectClassifierModel, params)
raise HTTPException(
status_code=404,
detail=f'Could not find {model_id} in defined models'
)
return session.describe_loaded_models()
@app.put('/i2i/infer/') @app.put('/i2i/infer/')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
import model_server.ilastik
import model_server.model
def get_all_model_subclasses():
"""
Recursively find all subclasses of Model
:return: set of all subclasses of Model
"""
def get_all_subclasses_of(cc):
return set(cc.__subclasses__()).union(
[s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)])
return get_all_subclasses_of(model_server.model.Model)
if __name__ == '__main__':
print(get_all_model_subclasses())
\ No newline at end of file
...@@ -7,7 +7,6 @@ from typing import Dict ...@@ -7,7 +7,6 @@ from typing import Dict
from conf.server import paths from conf.server import paths
from model_server.model import Model from model_server.model import Model
from model_server.model_registry import get_all_model_subclasses
from model_server.share import SharedImageDirectory from model_server.share import SharedImageDirectory
from model_server.workflow import WorkflowRunRecord from model_server.workflow import WorkflowRunRecord
...@@ -64,40 +63,30 @@ class Session(object): ...@@ -64,40 +63,30 @@ class Session(object):
with open(self.manifest_json, 'w+') as fh: with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh) json.dump(record.dict(), fh)
def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool:
def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict:
""" """
Load an instance of first model class that matches model_id string Load an instance of a given model class and attach to this session's model registry
:param model_id: string that uniquely defines a class of model :param ModelClass: subclass of Model
:param params: optional parameters that are passed upon loading a model :param params: optional parameters that are passed to the model's construct
:return: True if model successfully loaded, False if not :return: dictionary that describes all currently loaded models
""" """
models = get_all_model_subclasses() mi = ModelClass(params=params)
for mc in models: assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: ii = 0
try: def mid(i): return f'{ModelClass.__name__}_{ii:02d}'
mi = mc(params=params) while mid(ii) in self.models.keys():
assert mi.loaded ii += 1
except: self.models[mid(ii)] = {
raise CouldNotInstantiateModelError() 'object': mi,
self.models[model_id] = { 'params': params
'object': mi, }
'params': params, self.log_event(f'Loaded model {mid}')
} return self.describe_loaded_models()
self.log_event(f'Loaded model {model_id}')
return True
raise CouldNotFindModelError(
f'Could not find {model_id} in:\n{models}',
)
return False
def describe_loaded_models(self) -> dict: def describe_loaded_models(self) -> dict:
return { # TODO: explictly make this JSON-compatible
k: { return self.models
'class': self.models[k]['object'].__class__.__name__,
'params': self.models[k]['params'],
}
for k in self.models.keys()
}
def restart(self): def restart(self):
self.__init__() self.__init__()
...@@ -105,9 +94,6 @@ class Session(object): ...@@ -105,9 +94,6 @@ class Session(object):
class Error(Exception): class Error(Exception):
pass pass
class CouldNotFindModelError(Error):
pass
class InferenceRecordError(Error): class InferenceRecordError(Error):
pass pass
......
from multiprocessing import Process from multiprocessing import Process
import requests import requests
import unittest import unittest
...@@ -50,10 +49,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass): ...@@ -50,10 +49,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
def test_load_dummy_model(self): def test_load_dummy_model(self):
model_id = DummyImageToImageModel.model_id model_id = DummyImageToImageModel.model_id
resp_load = requests.put( resp_load = requests.put(
self.uri + f'models/load', self.uri + f'models/dummy/load',
params={'model_id': model_id} # params={'misc': {'d': 'e'}}
) )
self.assertEqual(resp_load.status_code, 200) self.assertEqual(resp_load.status_code, 200, resp_load.json())
resp_list = requests.get(self.uri + 'models') resp_list = requests.get(self.uri + 'models')
self.assertEqual(resp_list.status_code, 200) self.assertEqual(resp_list.status_code, 200)
rj = resp_list.json() rj = resp_list.json()
......
...@@ -37,14 +37,36 @@ class TestGetSessionObject(unittest.TestCase): ...@@ -37,14 +37,36 @@ class TestGetSessionObject(unittest.TestCase):
do = json.load(fh) do = json.load(fh)
self.assertEqual(di.dict(), do, 'Manifest record is not correct') self.assertEqual(di.dict(), do, 'Manifest record is not correct')
def test_session_loads_model(self): def test_session_loads_model(self):
sesh = Session() sesh = Session()
model_id = DummyImageToImageModel.model_id MC = DummyImageToImageModel
success = sesh.load_model(model_id) success = sesh.load_model(MC)
self.assertTrue(success) self.assertTrue(success)
loaded_models = sesh.describe_loaded_models() loaded_models = sesh.describe_loaded_models()
self.assertTrue(model_id in loaded_models.keys()) self.assertTrue(
(MC.__name__ + '_00') in loaded_models.keys()
)
self.assertEqual( self.assertEqual(
loaded_models[model_id]['class'], loaded_models[MC.__name__ + '_00']['object'].__class__,
DummyImageToImageModel.__name__ MC
) )
\ No newline at end of file
def test_session_loads_second_instance_of_same_model(self):
sesh = Session()
MC = DummyImageToImageModel
sesh.load_model(MC)
sesh.load_model(MC)
print(sesh.models.keys())
self.assertIn(MC.__name__ + '_00', sesh.models.keys())
self.assertIn(MC.__name__ + '_01', sesh.models.keys())
def test_session_loads_model_with_params(self):
sesh = Session()
MC = DummyImageToImageModel
p = {'p1': 'abc'}
success = sesh.load_model(MC, params=p)
self.assertTrue(success)
loaded_models = sesh.describe_loaded_models()
self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment