diff --git a/api.py b/api.py index 512083625be5b0d4a1a56539efe5b20108a90d34..0640ad67ecfc43bf7b5ea5e263dff3c613822d40 100644 --- a/api.py +++ b/api.py @@ -20,13 +20,14 @@ def read_root(): def list_active_models(): return session.models # TODO: include model type too -@app.get('/models/ilastik/load/') +@app.get('/models/{model_id}/load/') def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON? if model_id in session.models.keys(): raise HTTPException( status_code=409, detail=f'Model with id {model_id} has already been loaded' ) + session @app.post('/i2i/infer/{model_id}') # image file in, image file out def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: @@ -41,7 +42,13 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: session.inbound / imgf, session.models[model_id], session.outbound, - channel=channel + channel=channel, + # TODO: optional callback for status reporting ) session.record_workflow_run(record) - return record \ No newline at end of file + return record + +# TODO: report out model inference status +@app.get('/i2i/status/{model_id}') +def status_model_inference(model_id: str) -> dict: + pass \ No newline at end of file diff --git a/model_server/model.py b/model_server/model.py index c9c3140b697bd4543e603b760408b718e8c45914..6eafdf3d084c93d1a146ccb0642d91c13fa621d6 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -5,6 +5,7 @@ import numpy as np from model_server.image import GenericImageFileAccessor + class Model(ABC): def __init__(self, autoload=True): @@ -14,9 +15,18 @@ class Model(ABC): :param autoload: automatically load model and dependencies into memory if True """ self.autoload = autoload + if self.load(): + self.loaded = True + else: + self.loaded = False + raise CouldNotLoadModelError() @abstractmethod def load(self): + """ + Abstract method that carries out the expectedly time-consuming step of loading a model into memory + :return: True if successful, else False + """ pass @abstractmethod @@ -55,7 +65,7 @@ class DummyImageToImageModel(Model): model_id = 'dummy_make_white_square' def load(self): - self.loaded = True + return True def infer(self, img: GenericImageFileAccessor, channel=None) -> (np.ndarray, dict): super().infer(img, channel) @@ -65,8 +75,12 @@ class DummyImageToImageModel(Model): result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 return (result, {'success': True}) + class Error(Exception): pass class ChannelTooHighError(Error): + pass + +class CouldNotLoadModelError(Error): pass \ No newline at end of file diff --git a/model_server/session.py b/model_server/session.py index 21e6633292406cf0f2651db8c205f66d26322ec1..17958bb9ed0c881fb4974350c5ce86f27180ffba 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -5,6 +5,7 @@ from pathlib import Path from time import strftime, localtime from conf.server import paths +from model_server.model import Model from model_server.share import SharedImageDirectory from model_server.workflow import WorkflowRunRecord @@ -34,7 +35,6 @@ class Session(object): self.manifest_json = self.where_records / f'{self.session_id}-manifest.json' open(self.manifest_json, 'w').close() # instantiate empty json file - @staticmethod def create_session_id(look_where: Path) -> str: """ @@ -61,6 +61,21 @@ class Session(object): with open(self.manifest_json, 'w+') as fh: json.dump(record.dict(), fh) + def load_model(self, model_id: str) -> bool: + """ + Load an instance of first model class that matches model_id string + :param model_id: + :return: True if model successfully loaded, False if not + """ + for mc in Model.__subclasses__(): + if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: + mi = mc() + assert mi.loaded + self.models.append(mi) + return True + return False + + def restart(self): self.__init__() diff --git a/model_server/workflow.py b/model_server/workflow.py index 5f1a4a9205755a4cdd06d2540bada30aaaa66940..888b33745c63f7e0913333f0465f337d7f3d6712 100644 --- a/model_server/workflow.py +++ b/model_server/workflow.py @@ -30,6 +30,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict: assert (img.shape_dict['T'] == 1) # run model inference + # TODO: call this async / await and report out infer status to optional callback ch = kwargs.get('channel') outdata, messages = model.infer(img, channel=ch) dt_inf = time() - t0 diff --git a/tests/test_api.py b/tests/test_api.py index 2bdd245195e29fd20b8eb923c035fd5776595239..f7c778c06658922345de248be5a978ad739eff6f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,7 +1,9 @@ from multiprocessing import Process import requests import unittest -import uvicorn + +from conf.testing import czifile, output_path +from model_server.model import DummyImageToImageModel class TestApiFromAutomatedClient(unittest.TestCase): def setUp(self) -> None: @@ -25,3 +27,25 @@ class TestApiFromAutomatedClient(unittest.TestCase): resp = requests.get(self.uri, ) self.assertEqual(resp.status_code, 200) + def test_list_empty_loaded_models(self): + resp = requests.get(self.uri + 'models') + print(resp.content) + self.assertEqual(resp.status_code, 200) + + def test_load_model(self): + resp = requests.get(self.uri + 'load_') + + def test_i2i_inference_errors_model_not_sound(self): + model_id = 'not_a_real_model' + resp = requests.post(self.uri + f'i2i/infer/{model_id}') + self.assertEqual(resp.status_code, 404) + + def test_i2i_dummy_inference_by_api(self): + model = DummyImageToImageModel() + model_id = model.model_id + resp = requests.post( + self.uri + f'/i2i/infer/{model_id}', + str(czifile['path']), + ) + print(resp) + self.assertEqual(resp.status_code, 200) \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index 1920ea3b4d128a86c4909068404276276d9a8590..aff9db8639b8380f1ab89cf7f8f061ce7a1ea010 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,16 +1,30 @@ import unittest from conf.testing import czifile from model_server.image import CziImageFileAccessor -from model_server.model import DummyImageToImageModel +from model_server.model import DummyImageToImageModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) + def test_instantiate_model(self): + model = DummyImageToImageModel() + self.assertTrue(model.loaded) + def test_instantiate_model_with_nondefault_kwarg(self): model = DummyImageToImageModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') + def test_raise_error_if_cannot_load_model(self): + class UnloadableDummyImageToImageModel(DummyImageToImageModel): + def load(self): + return False + + self.assertRaises( + CouldNotLoadModelError, + mi=UnloadableDummyImageToImageModel, + ) + def test_czifile_is_correct_shape(self): model = DummyImageToImageModel() img, _ = model.infer(self.cf, channel=1)