diff --git a/api.py b/api.py index f5d7976fdd7f0bb154159bb29aeb8816f3a725eb..89e42d8fde48a22e8447964da59e32fe388fa5ee 100644 --- a/api.py +++ b/api.py @@ -1,4 +1,4 @@ -from pathlib import Path +from typing import Dict from fastapi import FastAPI, HTTPException @@ -18,21 +18,21 @@ def read_root(): @app.get('/models') def list_active_models(): - return session.describe_models() + return session.describe_loaded_models() -@app.get('/models/{model_id}/load/') -def load_model(model_id: str) -> dict: +@app.put('/models/load/') +def load_model(model_id: str, params: Dict[str, str] = None) -> dict: if model_id in session.models.keys(): raise HTTPException( status_code=409, detail=f'Model with id {model_id} has already been loaded' ) - session.load_model(model_id) - return session.describe_models() + session.load_model(model_id, params=params) + return session.describe_loaded_models() @app.put('/i2i/infer/{model_id}') # image file in, image file out def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: - if model_id not in session.describe_models().keys(): + if model_id not in session.describe_loaded_models().keys(): raise HTTPException( status_code=409, detail=f'Model {model_id} has not been loaded' @@ -43,10 +43,10 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: status_code=404, detail=f'Could not find file:\n{inpath}' ) - + model = session.models[model_id]['object'] record = infer_image_to_image( inpath, - session.models[model_id], + session.models[model_id]['object'], session.outbound.path, channel=channel, # TODO: optional callback for status reporting diff --git a/model_server/model.py b/model_server/model.py index 7c9d84e1184640703a6a0574d1f4c18d6b19cbac..d483744a7594e67328f0fcc03fdebfd9d3a36a5f 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from math import floor +import os import numpy as np @@ -8,13 +9,15 @@ from model_server.image import GenericImageFileAccessor class Model(ABC): - def __init__(self, autoload=True): + def __init__(self, autoload=True, params=None): """ Abstract base class for an inference model that uses image data as an input. :param autoload: automatically load model and dependencies into memory if True + :param params: Dict[str, str] of arguments e.g. configuration files required to load model """ self.autoload = autoload + self.params = params if self.load(): self.loaded = True else: @@ -58,20 +61,20 @@ class Model(ABC): class ImageToImageModel(Model): - - def __init__(self, **kwargs): - """ - Abstract class for models that receive an image and return an image of the same size - :param kwargs: variable length keyword arguments - """ - return super().__init__(**kwargs) + """ + Abstract class for models that receive an image and return an image of the same size + """ @abstractmethod def infer(self, img, channel=None) -> (np.ndarray, dict): super().infer(img, channel) class IlastikImageToImageModel(ImageToImageModel): - pass + def load(self): + if 'project_file' not in self.params or not os.path.exists(self.params['project_file']): + raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') + self.project_file = self.params['project_file'] + class DummyImageToImageModel(ImageToImageModel): @@ -96,4 +99,7 @@ class ChannelTooHighError(Error): pass class CouldNotLoadModelError(Error): + pass + +class ParameterExpectedError(Error): pass \ No newline at end of file diff --git a/model_server/session.py b/model_server/session.py index f3b65d62d087b9920c04c50596dc8d86cc67d739..5ab27fe99c9b35cd05a0a85012793b948a25447a 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -3,6 +3,7 @@ import os from pathlib import Path from time import strftime, localtime +from typing import Dict from conf.server import paths from model_server.model import Model @@ -61,27 +62,37 @@ class Session(object): with open(self.manifest_json, 'w+') as fh: json.dump(record.dict(), fh) - def load_model(self, model_id: str) -> bool: + def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool: """ Load an instance of first model class that matches model_id string - :param model_id: + :param model_id: string that uniquely defines a class of model + :param params: optional parameters that are passed upon loading a model :return: True if model successfully loaded, False if not """ models = Model.get_all_subclasses() for mc in models: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: - mi = mc() - assert mi.loaded - self.models[model_id] = mi + try: + mi = mc(params) + assert mi.loaded + except: + raise CouldNotInstantiateModelError() + self.models[model_id] = { + 'object': mi, + 'params': params, + } return True raise CouldNotFindModelError( f'Could not find {model_id} in:\n{models}', ) return False - def describe_models(self) -> dict: + def describe_loaded_models(self) -> dict: return { - k: self.models[k].__class__.__name__ + k: { + 'class': self.models[k]['object'].__class__.__name__, + 'params': self.models[k]['params'], + } for k in self.models.keys() } @@ -95,4 +106,7 @@ class CouldNotFindModelError(Error): pass class InferenceRecordError(Error): + pass + +class CouldNotInstantiateModelError(Error): pass \ No newline at end of file diff --git a/tests/test_api.py b/tests/test_api.py index ecc0296c600107b2df6ad4bebf2f57406bb01987..e290d16350f6e89495cc36bd63cbe949610006a3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -48,11 +48,16 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_load_model(self): model_id = DummyImageToImageModel.model_id - resp_load = requests.get(self.uri + f'models/{model_id}/load') + resp_load = requests.put( + self.uri + f'models/load', + params={'model_id': model_id} + ) self.assertEqual(resp_load.status_code, 200) resp_list = requests.get(self.uri + 'models') self.assertEqual(resp_list.status_code, 200) - self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}') + rj = resp_list.json() + self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel') + def test_i2i_inference_errors_model_not_found(self): model_id = 'not_a_real_model' @@ -65,7 +70,10 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_i2i_dummy_inference_by_api(self): model = DummyImageToImageModel() - resp_load = requests.get(self.uri + f'models/{model.model_id}/load') + resp_load = requests.put( + self.uri + f'models/load', + params={'model_id': model.model_id} + ) self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}') self.copy_input_file_to_server() resp_infer = requests.put( diff --git a/tests/test_session.py b/tests/test_session.py index 2538d3d2e455b1437b91715f3de0b997a2e3cced..99fd6c6dca75bc7a4b07b1f3f4f68bbf8837318c 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -42,9 +42,9 @@ class TestGetSessionObject(unittest.TestCase): model_id = DummyImageToImageModel.model_id success = sesh.load_model(model_id) self.assertTrue(success) - loaded_models = sesh.describe_models() + loaded_models = sesh.describe_loaded_models() self.assertTrue(model_id in loaded_models.keys()) self.assertEqual( - loaded_models[model_id], + loaded_models[model_id]['class'], DummyImageToImageModel.__name__ ) \ No newline at end of file