Skip to content
Snippets Groups Projects
Commit a29dc6bb authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Implemented model-loading logic at session level

parent d7642645
No related branches found
No related tags found
No related merge requests found
......@@ -20,13 +20,14 @@ def read_root():
def list_active_models():
return session.models # TODO: include model type too
@app.get('/models/ilastik/load/')
@app.get('/models/{model_id}/load/')
def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON?
if model_id in session.models.keys():
raise HTTPException(
status_code=409,
detail=f'Model with id {model_id} has already been loaded'
)
session
@app.post('/i2i/infer/{model_id}') # image file in, image file out
def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
......@@ -41,7 +42,13 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
session.inbound / imgf,
session.models[model_id],
session.outbound,
channel=channel
channel=channel,
# TODO: optional callback for status reporting
)
session.record_workflow_run(record)
return record
\ No newline at end of file
return record
# TODO: report out model inference status
@app.get('/i2i/status/{model_id}')
def status_model_inference(model_id: str) -> dict:
pass
\ No newline at end of file
......@@ -5,6 +5,7 @@ import numpy as np
from model_server.image import GenericImageFileAccessor
class Model(ABC):
def __init__(self, autoload=True):
......@@ -14,9 +15,18 @@ class Model(ABC):
:param autoload: automatically load model and dependencies into memory if True
"""
self.autoload = autoload
if self.load():
self.loaded = True
else:
self.loaded = False
raise CouldNotLoadModelError()
@abstractmethod
def load(self):
"""
Abstract method that carries out the expectedly time-consuming step of loading a model into memory
:return: True if successful, else False
"""
pass
@abstractmethod
......@@ -55,7 +65,7 @@ class DummyImageToImageModel(Model):
model_id = 'dummy_make_white_square'
def load(self):
self.loaded = True
return True
def infer(self, img: GenericImageFileAccessor, channel=None) -> (np.ndarray, dict):
super().infer(img, channel)
......@@ -65,8 +75,12 @@ class DummyImageToImageModel(Model):
result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
return (result, {'success': True})
class Error(Exception):
pass
class ChannelTooHighError(Error):
pass
class CouldNotLoadModelError(Error):
pass
\ No newline at end of file
......@@ -5,6 +5,7 @@ from pathlib import Path
from time import strftime, localtime
from conf.server import paths
from model_server.model import Model
from model_server.share import SharedImageDirectory
from model_server.workflow import WorkflowRunRecord
......@@ -34,7 +35,6 @@ class Session(object):
self.manifest_json = self.where_records / f'{self.session_id}-manifest.json'
open(self.manifest_json, 'w').close() # instantiate empty json file
@staticmethod
def create_session_id(look_where: Path) -> str:
"""
......@@ -61,6 +61,21 @@ class Session(object):
with open(self.manifest_json, 'w+') as fh:
json.dump(record.dict(), fh)
def load_model(self, model_id: str) -> bool:
"""
Load an instance of first model class that matches model_id string
:param model_id:
:return: True if model successfully loaded, False if not
"""
for mc in Model.__subclasses__():
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
mi = mc()
assert mi.loaded
self.models.append(mi)
return True
return False
def restart(self):
self.__init__()
......
......@@ -30,6 +30,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
assert (img.shape_dict['T'] == 1)
# run model inference
# TODO: call this async / await and report out infer status to optional callback
ch = kwargs.get('channel')
outdata, messages = model.infer(img, channel=ch)
dt_inf = time() - t0
......
from multiprocessing import Process
import requests
import unittest
import uvicorn
from conf.testing import czifile, output_path
from model_server.model import DummyImageToImageModel
class TestApiFromAutomatedClient(unittest.TestCase):
def setUp(self) -> None:
......@@ -25,3 +27,25 @@ class TestApiFromAutomatedClient(unittest.TestCase):
resp = requests.get(self.uri, )
self.assertEqual(resp.status_code, 200)
def test_list_empty_loaded_models(self):
resp = requests.get(self.uri + 'models')
print(resp.content)
self.assertEqual(resp.status_code, 200)
def test_load_model(self):
resp = requests.get(self.uri + 'load_')
def test_i2i_inference_errors_model_not_sound(self):
model_id = 'not_a_real_model'
resp = requests.post(self.uri + f'i2i/infer/{model_id}')
self.assertEqual(resp.status_code, 404)
def test_i2i_dummy_inference_by_api(self):
model = DummyImageToImageModel()
model_id = model.model_id
resp = requests.post(
self.uri + f'/i2i/infer/{model_id}',
str(czifile['path']),
)
print(resp)
self.assertEqual(resp.status_code, 200)
\ No newline at end of file
import unittest
from conf.testing import czifile
from model_server.image import CziImageFileAccessor
from model_server.model import DummyImageToImageModel
from model_server.model import DummyImageToImageModel, CouldNotLoadModelError
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self):
model = DummyImageToImageModel()
self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self):
model = DummyImageToImageModel(autoload=False)
self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
def test_raise_error_if_cannot_load_model(self):
class UnloadableDummyImageToImageModel(DummyImageToImageModel):
def load(self):
return False
self.assertRaises(
CouldNotLoadModelError,
mi=UnloadableDummyImageToImageModel,
)
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
img, _ = model.infer(self.cf, channel=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment