Skip to content
Snippets Groups Projects
Commit a29dc6bb authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Implemented model-loading logic at session level

parent d7642645
No related branches found
No related tags found
No related merge requests found
...@@ -20,13 +20,14 @@ def read_root(): ...@@ -20,13 +20,14 @@ def read_root():
def list_active_models(): def list_active_models():
return session.models # TODO: include model type too return session.models # TODO: include model type too
@app.get('/models/ilastik/load/') @app.get('/models/{model_id}/load/')
def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON? def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON?
if model_id in session.models.keys(): if model_id in session.models.keys():
raise HTTPException( raise HTTPException(
status_code=409, status_code=409,
detail=f'Model with id {model_id} has already been loaded' detail=f'Model with id {model_id} has already been loaded'
) )
session
@app.post('/i2i/infer/{model_id}') # image file in, image file out @app.post('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
...@@ -41,7 +42,13 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: ...@@ -41,7 +42,13 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
session.inbound / imgf, session.inbound / imgf,
session.models[model_id], session.models[model_id],
session.outbound, session.outbound,
channel=channel channel=channel,
# TODO: optional callback for status reporting
) )
session.record_workflow_run(record) session.record_workflow_run(record)
return record return record
\ No newline at end of file
# TODO: report out model inference status
@app.get('/i2i/status/{model_id}')
def status_model_inference(model_id: str) -> dict:
pass
\ No newline at end of file
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from model_server.image import GenericImageFileAccessor from model_server.image import GenericImageFileAccessor
class Model(ABC): class Model(ABC):
def __init__(self, autoload=True): def __init__(self, autoload=True):
...@@ -14,9 +15,18 @@ class Model(ABC): ...@@ -14,9 +15,18 @@ class Model(ABC):
:param autoload: automatically load model and dependencies into memory if True :param autoload: automatically load model and dependencies into memory if True
""" """
self.autoload = autoload self.autoload = autoload
if self.load():
self.loaded = True
else:
self.loaded = False
raise CouldNotLoadModelError()
@abstractmethod @abstractmethod
def load(self): def load(self):
"""
Abstract method that carries out the expectedly time-consuming step of loading a model into memory
:return: True if successful, else False
"""
pass pass
@abstractmethod @abstractmethod
...@@ -55,7 +65,7 @@ class DummyImageToImageModel(Model): ...@@ -55,7 +65,7 @@ class DummyImageToImageModel(Model):
model_id = 'dummy_make_white_square' model_id = 'dummy_make_white_square'
def load(self): def load(self):
self.loaded = True return True
def infer(self, img: GenericImageFileAccessor, channel=None) -> (np.ndarray, dict): def infer(self, img: GenericImageFileAccessor, channel=None) -> (np.ndarray, dict):
super().infer(img, channel) super().infer(img, channel)
...@@ -65,8 +75,12 @@ class DummyImageToImageModel(Model): ...@@ -65,8 +75,12 @@ class DummyImageToImageModel(Model):
result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
return (result, {'success': True}) return (result, {'success': True})
class Error(Exception): class Error(Exception):
pass pass
class ChannelTooHighError(Error): class ChannelTooHighError(Error):
pass
class CouldNotLoadModelError(Error):
pass pass
\ No newline at end of file
...@@ -5,6 +5,7 @@ from pathlib import Path ...@@ -5,6 +5,7 @@ from pathlib import Path
from time import strftime, localtime from time import strftime, localtime
from conf.server import paths from conf.server import paths
from model_server.model import Model
from model_server.share import SharedImageDirectory from model_server.share import SharedImageDirectory
from model_server.workflow import WorkflowRunRecord from model_server.workflow import WorkflowRunRecord
...@@ -34,7 +35,6 @@ class Session(object): ...@@ -34,7 +35,6 @@ class Session(object):
self.manifest_json = self.where_records / f'{self.session_id}-manifest.json' self.manifest_json = self.where_records / f'{self.session_id}-manifest.json'
open(self.manifest_json, 'w').close() # instantiate empty json file open(self.manifest_json, 'w').close() # instantiate empty json file
@staticmethod @staticmethod
def create_session_id(look_where: Path) -> str: def create_session_id(look_where: Path) -> str:
""" """
...@@ -61,6 +61,21 @@ class Session(object): ...@@ -61,6 +61,21 @@ class Session(object):
with open(self.manifest_json, 'w+') as fh: with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh) json.dump(record.dict(), fh)
def load_model(self, model_id: str) -> bool:
"""
Load an instance of first model class that matches model_id string
:param model_id:
:return: True if model successfully loaded, False if not
"""
for mc in Model.__subclasses__():
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc()
assert mi.loaded
self.models.append(mi)
return True
return False
def restart(self): def restart(self):
self.__init__() self.__init__()
......
...@@ -30,6 +30,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict: ...@@ -30,6 +30,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
assert (img.shape_dict['T'] == 1) assert (img.shape_dict['T'] == 1)
# run model inference # run model inference
# TODO: call this async / await and report out infer status to optional callback
ch = kwargs.get('channel') ch = kwargs.get('channel')
outdata, messages = model.infer(img, channel=ch) outdata, messages = model.infer(img, channel=ch)
dt_inf = time() - t0 dt_inf = time() - t0
......
from multiprocessing import Process from multiprocessing import Process
import requests import requests
import unittest import unittest
import uvicorn
from conf.testing import czifile, output_path
from model_server.model import DummyImageToImageModel
class TestApiFromAutomatedClient(unittest.TestCase): class TestApiFromAutomatedClient(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
...@@ -25,3 +27,25 @@ class TestApiFromAutomatedClient(unittest.TestCase): ...@@ -25,3 +27,25 @@ class TestApiFromAutomatedClient(unittest.TestCase):
resp = requests.get(self.uri, ) resp = requests.get(self.uri, )
self.assertEqual(resp.status_code, 200) self.assertEqual(resp.status_code, 200)
def test_list_empty_loaded_models(self):
resp = requests.get(self.uri + 'models')
print(resp.content)
self.assertEqual(resp.status_code, 200)
def test_load_model(self):
resp = requests.get(self.uri + 'load_')
def test_i2i_inference_errors_model_not_sound(self):
model_id = 'not_a_real_model'
resp = requests.post(self.uri + f'i2i/infer/{model_id}')
self.assertEqual(resp.status_code, 404)
def test_i2i_dummy_inference_by_api(self):
model = DummyImageToImageModel()
model_id = model.model_id
resp = requests.post(
self.uri + f'/i2i/infer/{model_id}',
str(czifile['path']),
)
print(resp)
self.assertEqual(resp.status_code, 200)
\ No newline at end of file
import unittest import unittest
from conf.testing import czifile from conf.testing import czifile
from model_server.image import CziImageFileAccessor from model_server.image import CziImageFileAccessor
from model_server.model import DummyImageToImageModel from model_server.model import DummyImageToImageModel, CouldNotLoadModelError
class TestCziImageFileAccess(unittest.TestCase): class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path']) self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self):
model = DummyImageToImageModel()
self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self): def test_instantiate_model_with_nondefault_kwarg(self):
model = DummyImageToImageModel(autoload=False) model = DummyImageToImageModel(autoload=False)
self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
def test_raise_error_if_cannot_load_model(self):
class UnloadableDummyImageToImageModel(DummyImageToImageModel):
def load(self):
return False
self.assertRaises(
CouldNotLoadModelError,
mi=UnloadableDummyImageToImageModel,
)
def test_czifile_is_correct_shape(self): def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel() model = DummyImageToImageModel()
img, _ = model.infer(self.cf, channel=1) img, _ = model.infer(self.cf, channel=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment