From 097a6ca2a508e80034f94aecef7f8174335d36f8 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 6 May 2024 14:38:51 +0200
Subject: [PATCH] Models now generically use Pydyantic models for
 parameterization

---
 model_server/base/models.py  |  7 ++++---
 model_server/base/session.py |  4 ++--
 tests/test_session.py        | 20 +++++++++++++-------
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/model_server/base/models.py b/model_server/base/models.py
index 61d064b1..228bcb66 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -10,15 +10,16 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc
 
 class Model(ABC):
 
-    def __init__(self, autoload=True, params=Union[BaseModel, None]):
+    def __init__(self, autoload=True, params: BaseModel = None):
         """
         Abstract base class for an inference model that uses image data as an input.
 
         :param autoload: automatically load model and dependencies into memory if True
-        :param params: Dict[str, str] of arguments e.g. configuration files required to load model
+        :param params: (optional) BaseModel of model parameters e.g. configuration files required to load model
         """
         self.autoload = autoload
-        self.params = params.dict()
+        if params:
+            self.params = params.dict()
         self.loaded = False
         if not autoload:
             return None
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 7639edf2..9b7b1acf 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -140,7 +140,7 @@ class Session(object, metaclass=Singleton):
     def log_error(self, msg):
         logger.error(msg)
 
-    def load_model(self, ModelClass: Model, params: Union[BaseModel, None]) -> dict:
+    def load_model(self, ModelClass: Model, params: Union[BaseModel, None] = None) -> dict:
         """
         Load an instance of a given model class and attach to this session's model registry
         :param ModelClass: subclass of Model
@@ -160,7 +160,7 @@ class Session(object, metaclass=Singleton):
         key = mid(ii)
         self.models[key] = {
             'object': mi,
-            'params': mi.params
+            'params': getattr(mi, 'params', None)
         }
         self.log_info(f'Loaded model {key}')
         return key
diff --git a/tests/test_session.py b/tests/test_session.py
index aafda3c2..31d62903 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,6 +1,7 @@
 import json
 from os.path import exists
 import pathlib
+from pydantic import BaseModel
 import unittest
 
 from model_server.base.models import DummySemanticSegmentationModel
@@ -95,7 +96,9 @@ class TestGetSessionObject(unittest.TestCase):
 
     def test_session_loads_model_with_params(self):
         MC = DummySemanticSegmentationModel
-        p1 = {'p1': 'abc'}
+        class _PM(BaseModel):
+            p: str
+        p1 = _PM(p='abc')
         success = self.sesh.load_model(MC, params=p1)
         self.assertTrue(success)
         loaded_models = self.sesh.describe_loaded_models()
@@ -103,19 +106,22 @@ class TestGetSessionObject(unittest.TestCase):
         self.assertEqual(loaded_models[mid]['params'], p1)
 
         # load a second model and confirm that the first is locatable by its param entry
-        p2 = {'p2': 'def'}
+        p2 = _PM(p='def')
         self.sesh.load_model(MC, params=p2)
-        find_mid = self.sesh.find_param_in_loaded_models('p1', 'abc')
+        find_mid = self.sesh.find_param_in_loaded_models('p', 'abc')
         self.assertEqual(mid, find_mid)
         self.assertEqual(self.sesh.describe_loaded_models()[mid]['params'], p1)
 
     def test_session_finds_existing_model_with_different_path_formats(self):
         MC = DummySemanticSegmentationModel
-        p1 = {'path': 'c:\\windows\\dummy.pa'}
-        p2 = {'path': 'c:/windows/dummy.pa'}
+        class _PM(BaseModel):
+            path: str
+
+        p1 = _PM(path='c:\\windows\\dummy.pa')
+        p2 = _PM(path='c:/windows/dummy.pa')
         mid = self.sesh.load_model(MC, params=p1)
-        assert pathlib.Path(p1['path']) == pathlib.Path(p2['path'])
-        find_mid = self.sesh.find_param_in_loaded_models('path', p2['path'], is_path=True)
+        assert pathlib.Path(p1.path) == pathlib.Path(p2.path)
+        find_mid = self.sesh.find_param_in_loaded_models('path', p2.path, is_path=True)
         self.assertEqual(mid, find_mid)
 
     def test_change_output_path(self):
-- 
GitLab