diff --git a/api.py b/api.py index 0167a17c3f1a3196b7ba4b5c4c16bd750d6f7cbb..9518066f46722be53132db03c4508b8d156f6632 100644 --- a/api.py +++ b/api.py @@ -2,7 +2,9 @@ from typing import Dict from fastapi import FastAPI, HTTPException -from model_server.session import CouldNotFindModelError, Session +from model_server.ilastik import IlastikPixelClassifierModel, IlastikObjectClassifierModel +from model_server.model import DummyImageToImageModel +from model_server.session import Session from model_server.workflow import infer_image_to_image app = FastAPI(debug=True) @@ -20,22 +22,17 @@ def read_root(): def list_active_models(): return session.describe_loaded_models() -@app.put('/models/load/') -# def load_model(model_id: str, misc: Dict[str, str]) -> dict: -def load_model(model_id: str, misc: dict) -> dict: - if model_id in session.models.keys(): - raise HTTPException( - status_code=409, - detail=f'Model with id {model_id} has already been loaded' - ) - try: - session.load_model(model_id, params=misc) - except CouldNotFindModelError: - raise HTTPException( - status_code=404, - detail=f'Could not find {model_id} in defined models' - ) - return session.describe_loaded_models() +@app.put('/models/dummy/load/') +def load_dummy_model(params: str = None) -> dict: + return session.load_model(DummyImageToImageModel, params) + +@app.put('/models/ilastik/pixel_classification/load/') +def load_ilastik_pixel_classification_model(params: str) -> dict: + return session.load_model(IlastikPixelClassifierModel, params) + +@app.put('/models/ilastik/object_classification/load/') +def load_ilastik_object_classification_model(params: str) -> dict: + return session.load_model(IlastikObjectClassifierModel, params) @app.put('/i2i/infer/') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/model_server/model_registry.py b/model_server/model_registry.py deleted file mode 100644 index f6686dc08c0bcd9978bac2f241f7e25507da8ad8..0000000000000000000000000000000000000000 --- a/model_server/model_registry.py +++ /dev/null @@ -1,17 +0,0 @@ -import model_server.ilastik -import model_server.model - -def get_all_model_subclasses(): - """ - Recursively find all subclasses of Model - :return: set of all subclasses of Model - """ - - def get_all_subclasses_of(cc): - return set(cc.__subclasses__()).union( - [s for c in cc.__subclasses__() for s in get_all_subclasses_of(c)]) - - return get_all_subclasses_of(model_server.model.Model) - -if __name__ == '__main__': - print(get_all_model_subclasses()) \ No newline at end of file diff --git a/model_server/session.py b/model_server/session.py index 69a94fcb60e051faf08564fcde0fce6c862aafac..0ce71e1388dcf153ee4e24571a410c801ade691c 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -7,7 +7,6 @@ from typing import Dict from conf.server import paths from model_server.model import Model -from model_server.model_registry import get_all_model_subclasses from model_server.share import SharedImageDirectory from model_server.workflow import WorkflowRunRecord @@ -64,40 +63,30 @@ class Session(object): with open(self.manifest_json, 'w+') as fh: json.dump(record.dict(), fh) - def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool: + + def load_model(self, ModelClass: Model, params: Dict[str, str] = None) -> dict: """ - Load an instance of first model class that matches model_id string - :param model_id: string that uniquely defines a class of model - :param params: optional parameters that are passed upon loading a model - :return: True if model successfully loaded, False if not + Load an instance of a given model class and attach to this session's model registry + :param ModelClass: subclass of Model + :param params: optional parameters that are passed to the model's construct + :return: dictionary that describes all currently loaded models """ - models = get_all_model_subclasses() - for mc in models: - if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: - try: - mi = mc(params=params) - assert mi.loaded - except: - raise CouldNotInstantiateModelError() - self.models[model_id] = { - 'object': mi, - 'params': params, - } - self.log_event(f'Loaded model {model_id}') - return True - raise CouldNotFindModelError( - f'Could not find {model_id} in:\n{models}', - ) - return False + mi = ModelClass(params=params) + assert mi.loaded, f'Error loading instance of {ModelClass.__name__}' + ii = 0 + def mid(i): return f'{ModelClass.__name__}_{ii:02d}' + while mid(ii) in self.models.keys(): + ii += 1 + self.models[mid(ii)] = { + 'object': mi, + 'params': params + } + self.log_event(f'Loaded model {mid}') + return self.describe_loaded_models() def describe_loaded_models(self) -> dict: - return { - k: { - 'class': self.models[k]['object'].__class__.__name__, - 'params': self.models[k]['params'], - } - for k in self.models.keys() - } + # TODO: explictly make this JSON-compatible + return self.models def restart(self): self.__init__() @@ -105,9 +94,6 @@ class Session(object): class Error(Exception): pass -class CouldNotFindModelError(Error): - pass - class InferenceRecordError(Error): pass diff --git a/tests/test_api.py b/tests/test_api.py index bcb366eca8fa82ab5828e27ae0c1c5a29d1ff595..4e7c4819482c21a2fef2e9bf563747d92f21dc97 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,4 @@ from multiprocessing import Process - import requests import unittest @@ -50,10 +49,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass): def test_load_dummy_model(self): model_id = DummyImageToImageModel.model_id resp_load = requests.put( - self.uri + f'models/load', - params={'model_id': model_id} + self.uri + f'models/dummy/load', + # params={'misc': {'d': 'e'}} ) - self.assertEqual(resp_load.status_code, 200) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) rj = resp_list.json() diff --git a/tests/test_session.py b/tests/test_session.py index 99fd6c6dca75bc7a4b07b1f3f4f68bbf8837318c..2cfead8e44a970b9a89923ee0452f7c3c55027b3 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -37,14 +37,36 @@ class TestGetSessionObject(unittest.TestCase): do = json.load(fh) self.assertEqual(di.dict(), do, 'Manifest record is not correct') + def test_session_loads_model(self): sesh = Session() - model_id = DummyImageToImageModel.model_id - success = sesh.load_model(model_id) + MC = DummyImageToImageModel + success = sesh.load_model(MC) self.assertTrue(success) loaded_models = sesh.describe_loaded_models() - self.assertTrue(model_id in loaded_models.keys()) + self.assertTrue( + (MC.__name__ + '_00') in loaded_models.keys() + ) self.assertEqual( - loaded_models[model_id]['class'], - DummyImageToImageModel.__name__ - ) \ No newline at end of file + loaded_models[MC.__name__ + '_00']['object'].__class__, + MC + ) + + def test_session_loads_second_instance_of_same_model(self): + sesh = Session() + MC = DummyImageToImageModel + sesh.load_model(MC) + sesh.load_model(MC) + print(sesh.models.keys()) + self.assertIn(MC.__name__ + '_00', sesh.models.keys()) + self.assertIn(MC.__name__ + '_01', sesh.models.keys()) + + + def test_session_loads_model_with_params(self): + sesh = Session() + MC = DummyImageToImageModel + p = {'p1': 'abc'} + success = sesh.load_model(MC, params=p) + self.assertTrue(success) + loaded_models = sesh.describe_loaded_models() + self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p) \ No newline at end of file