diff --git a/model_server/base/models.py b/model_server/base/models.py index 61d064b1565df40d7d887913a99a8131563f71f7..228bcb66a7451837eaa6fd3b23a4fe489c335b95 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -10,15 +10,16 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc class Model(ABC): - def __init__(self, autoload=True, params=Union[BaseModel, None]): + def __init__(self, autoload=True, params: BaseModel = None): """ Abstract base class for an inference model that uses image data as an input. :param autoload: automatically load model and dependencies into memory if True - :param params: Dict[str, str] of arguments e.g. configuration files required to load model + :param params: (optional) BaseModel of model parameters e.g. configuration files required to load model """ self.autoload = autoload - self.params = params.dict() + if params: + self.params = params.dict() self.loaded = False if not autoload: return None diff --git a/model_server/base/session.py b/model_server/base/session.py index 7639edf2328145442d8997a55969f986a058987b..9b7b1acfa89b5d502b5579cc03c14525027d14a2 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -140,7 +140,7 @@ class Session(object, metaclass=Singleton): def log_error(self, msg): logger.error(msg) - def load_model(self, ModelClass: Model, params: Union[BaseModel, None]) -> dict: + def load_model(self, ModelClass: Model, params: Union[BaseModel, None] = None) -> dict: """ Load an instance of a given model class and attach to this session's model registry :param ModelClass: subclass of Model @@ -160,7 +160,7 @@ class Session(object, metaclass=Singleton): key = mid(ii) self.models[key] = { 'object': mi, - 'params': mi.params + 'params': getattr(mi, 'params', None) } self.log_info(f'Loaded model {key}') return key diff --git a/tests/test_session.py b/tests/test_session.py index aafda3c27f7146edc9e5235e660bc311d579e938..31d6290392537922f87aa1e09e2f61317e9e9537 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,6 +1,7 @@ import json from os.path import exists import pathlib +from pydantic import BaseModel import unittest from model_server.base.models import DummySemanticSegmentationModel @@ -95,7 +96,9 @@ class TestGetSessionObject(unittest.TestCase): def test_session_loads_model_with_params(self): MC = DummySemanticSegmentationModel - p1 = {'p1': 'abc'} + class _PM(BaseModel): + p: str + p1 = _PM(p='abc') success = self.sesh.load_model(MC, params=p1) self.assertTrue(success) loaded_models = self.sesh.describe_loaded_models() @@ -103,19 +106,22 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(loaded_models[mid]['params'], p1) # load a second model and confirm that the first is locatable by its param entry - p2 = {'p2': 'def'} + p2 = _PM(p='def') self.sesh.load_model(MC, params=p2) - find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc') + find_mid = self.sesh.find_param_in_loaded_models('p', 'abc') self.assertEqual(mid, find_mid) self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1) def test_session_finds_existing_model_with_different_path_formats(self): MC = DummySemanticSegmentationModel - p1 = {'path': 'c:\\windows\\dummy.pa'} - p2 = {'path': 'c:/windows/dummy.pa'} + class _PM(BaseModel): + path: str + + p1 = _PM(path='c:\\windows\\dummy.pa') + p2 = _PM(path='c:/windows/dummy.pa') mid = self.sesh.load_model(MC, params=p1) - assert pathlib.Path(p1['path']) == pathlib.Path(p2['path']) - find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True) + assert pathlib.Path(p1.path) == pathlib.Path(p2.path) + find_mid = self.sesh.find_param_in_loaded_models('path', p2.path, is_path=True) self.assertEqual(mid, find_mid) def test_change_output_path(self):