Skip to content
Snippets Groups Projects
Commit 097a6ca2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Models now generically use Pydyantic models for parameterization

parent fff8a36f
No related branches found
No related tags found
2 merge requests!50Release 2024.06.03,!42Models are now initialized with pydantic models
......@@ -10,15 +10,16 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc
class Model(ABC):
def __init__(self, autoload=True, params=Union[BaseModel, None]):
def __init__(self, autoload=True, params: BaseModel = None):
"""
Abstract base class for an inference model that uses image data as an input.
:param autoload: automatically load model and dependencies into memory if True
:param params: Dict[str, str] of arguments e.g. configuration files required to load model
:param params: (optional) BaseModel of model parameters e.g. configuration files required to load model
"""
self.autoload = autoload
self.params = params.dict()
if params:
self.params = params.dict()
self.loaded = False
if not autoload:
return None
......
......@@ -140,7 +140,7 @@ class Session(object, metaclass=Singleton):
def log_error(self, msg):
logger.error(msg)
def load_model(self, ModelClass: Model, params: Union[BaseModel, None]) -> dict:
def load_model(self, ModelClass: Model, params: Union[BaseModel, None] = None) -> dict:
"""
Load an instance of a given model class and attach to this session's model registry
:param ModelClass: subclass of Model
......@@ -160,7 +160,7 @@ class Session(object, metaclass=Singleton):
key = mid(ii)
self.models[key] = {
'object': mi,
'params': mi.params
'params': getattr(mi, 'params', None)
}
self.log_info(f'Loaded model {key}')
return key
......
import json
from os.path import exists
import pathlib
from pydantic import BaseModel
import unittest
from model_server.base.models import DummySemanticSegmentationModel
......@@ -95,7 +96,9 @@ class TestGetSessionObject(unittest.TestCase):
def test_session_loads_model_with_params(self):
MC = DummySemanticSegmentationModel
p1 = {'p1': 'abc'}
class _PM(BaseModel):
p: str
p1 = _PM(p='abc')
success = self.sesh.load_model(MC, params=p1)
self.assertTrue(success)
loaded_models = self.sesh.describe_loaded_models()
......@@ -103,19 +106,22 @@ class TestGetSessionObject(unittest.TestCase):
self.assertEqual(loaded_models[mid]['params'], p1)
# load a second model and confirm that the first is locatable by its param entry
p2 = {'p2': 'def'}
p2 = _PM(p='def')
self.sesh.load_model(MC, params=p2)
find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
self.assertEqual(mid, find_mid)
self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
def test_session_finds_existing_model_with_different_path_formats(self):
MC = DummySemanticSegmentationModel
p1 = {'path': 'c:\\windows\\dummy.pa'}
p2 = {'path': 'c:/windows/dummy.pa'}
class _PM(BaseModel):
path: str
p1 = _PM(path='c:\\windows\\dummy.pa')
p2 = _PM(path='c:/windows/dummy.pa')
mid = self.sesh.load_model(MC, params=p1)
assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
assert pathlib.Path(p1.path) == pathlib.Path(p2.path)
find_mid = self.sesh.find_param_in_loaded_models('path', p2.path, is_path=True)
self.assertEqual(mid, find_mid)
def test_change_output_path(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment