From 125847bcaa9a6e6fe53a4fdb523229a138792c3a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 28 Aug 2023 15:08:13 +0200
Subject: [PATCH] Reconfigured inference endpoint as pure PUT; can now pass
 params to model inference method

---
 api.py                  | 18 +++++++++---------
 model_server/model.py   | 24 +++++++++++++++---------
 model_server/session.py | 28 +++++++++++++++++++++-------
 tests/test_api.py       | 14 +++++++++++---
 tests/test_session.py   |  4 ++--
 5 files changed, 58 insertions(+), 30 deletions(-)

diff --git a/api.py b/api.py
index f5d7976f..89e42d8f 100644
--- a/api.py
+++ b/api.py
@@ -1,4 +1,4 @@
-from pathlib import Path
+from typing import Dict
 
 from fastapi import FastAPI, HTTPException
 
@@ -18,21 +18,21 @@ def read_root():
 
 @app.get('/models')
 def list_active_models():
-    return session.describe_models()
+    return session.describe_loaded_models()
 
-@app.get('/models/{model_id}/load/')
-def load_model(model_id: str) -> dict:
+@app.put('/models/load/')
+def load_model(model_id: str, params: Dict[str, str] = None) -> dict:
     if model_id in session.models.keys():
         raise HTTPException(
             status_code=409,
             detail=f'Model with id {model_id} has already been loaded'
         )
-    session.load_model(model_id)
-    return session.describe_models()
+    session.load_model(model_id, params=params)
+    return session.describe_loaded_models()
 
 @app.put('/i2i/infer/{model_id}') # image file in, image file out
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
-    if model_id not in session.describe_models().keys():
+    if model_id not in session.describe_loaded_models().keys():
         raise HTTPException(
             status_code=409,
             detail=f'Model {model_id} has not been loaded'
@@ -43,10 +43,10 @@ def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
             status_code=404,
             detail=f'Could not find file:\n{inpath}'
         )
-
+    model = session.models[model_id]['object']
     record = infer_image_to_image(
         inpath,
-        session.models[model_id],
+        session.models[model_id]['object'],
         session.outbound.path,
         channel=channel,
         # TODO: optional callback for status reporting
diff --git a/model_server/model.py b/model_server/model.py
index 7c9d84e1..d483744a 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from math import floor
+import os
 
 import numpy as np
 
@@ -8,13 +9,15 @@ from model_server.image import GenericImageFileAccessor
 
 class Model(ABC):
 
-    def __init__(self, autoload=True):
+    def __init__(self, autoload=True, params=None):
         """
         Abstract base class for an inference model that uses image data as an input.
 
         :param autoload: automatically load model and dependencies into memory if True
+        :param params: Dict[str, str] of arguments e.g. configuration files required to load model
         """
         self.autoload = autoload
+        self.params = params
         if self.load():
             self.loaded = True
         else:
@@ -58,20 +61,20 @@ class Model(ABC):
 
 
 class ImageToImageModel(Model):
-
-    def __init__(self, **kwargs):
-        """
-        Abstract class for models that receive an image and return an image of the same size
-        :param kwargs: variable length keyword arguments
-        """
-        return super().__init__(**kwargs)
+    """
+    Abstract class for models that receive an image and return an image of the same size
+    """
 
     @abstractmethod
     def infer(self, img, channel=None) -> (np.ndarray, dict):
         super().infer(img, channel)
 
 class IlastikImageToImageModel(ImageToImageModel):
-    pass
+    def load(self):
+        if 'project_file' not in self.params or not os.path.exists(self.params['project_file']):
+            raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
+        self.project_file = self.params['project_file']
+
 
 class DummyImageToImageModel(ImageToImageModel):
 
@@ -96,4 +99,7 @@ class ChannelTooHighError(Error):
     pass
 
 class CouldNotLoadModelError(Error):
+    pass
+
+class ParameterExpectedError(Error):
     pass
\ No newline at end of file
diff --git a/model_server/session.py b/model_server/session.py
index f3b65d62..5ab27fe9 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -3,6 +3,7 @@ import os
 
 from pathlib import Path
 from time import strftime, localtime
+from typing import Dict
 
 from conf.server import paths
 from model_server.model import Model
@@ -61,27 +62,37 @@ class Session(object):
         with open(self.manifest_json, 'w+') as fh:
             json.dump(record.dict(), fh)
 
-    def load_model(self, model_id: str) -> bool:
+    def load_model(self, model_id: str, params: Dict[str, str] = None) -> bool:
         """
         Load an instance of first model class that matches model_id string
-        :param model_id:
+        :param model_id: string that uniquely defines a class of model
+        :param params: optional parameters that are passed upon loading a model
         :return: True if model successfully loaded, False if not
         """
         models = Model.get_all_subclasses()
         for mc in models:
             if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
-                mi = mc()
-                assert mi.loaded
-                self.models[model_id] = mi
+                try:
+                    mi = mc(params)
+                    assert mi.loaded
+                except:
+                    raise CouldNotInstantiateModelError()
+                self.models[model_id] = {
+                    'object': mi,
+                    'params': params,
+                }
                 return True
         raise CouldNotFindModelError(
             f'Could not find {model_id} in:\n{models}',
         )
         return False
 
-    def describe_models(self) -> dict:
+    def describe_loaded_models(self) -> dict:
         return {
-            k: self.models[k].__class__.__name__
+            k: {
+                'class': self.models[k]['object'].__class__.__name__,
+                'params': self.models[k]['params'],
+            }
             for k in self.models.keys()
         }
 
@@ -95,4 +106,7 @@ class CouldNotFindModelError(Error):
     pass
 
 class InferenceRecordError(Error):
+    pass
+
+class CouldNotInstantiateModelError(Error):
     pass
\ No newline at end of file
diff --git a/tests/test_api.py b/tests/test_api.py
index ecc0296c..e290d163 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -48,11 +48,16 @@ class TestApiFromAutomatedClient(unittest.TestCase):
 
     def test_load_model(self):
         model_id = DummyImageToImageModel.model_id
-        resp_load = requests.get(self.uri + f'models/{model_id}/load')
+        resp_load = requests.put(
+            self.uri + f'models/load',
+            params={'model_id': model_id}
+        )
         self.assertEqual(resp_load.status_code, 200)
         resp_list = requests.get(self.uri + 'models')
         self.assertEqual(resp_list.status_code, 200)
-        self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}')
+        rj = resp_list.json()
+        self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
+
 
     def test_i2i_inference_errors_model_not_found(self):
         model_id = 'not_a_real_model'
@@ -65,7 +70,10 @@ class TestApiFromAutomatedClient(unittest.TestCase):
 
     def test_i2i_dummy_inference_by_api(self):
         model = DummyImageToImageModel()
-        resp_load = requests.get(self.uri + f'models/{model.model_id}/load')
+        resp_load = requests.put(
+            self.uri + f'models/load',
+            params={'model_id': model.model_id}
+        )
         self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
         self.copy_input_file_to_server()
         resp_infer = requests.put(
diff --git a/tests/test_session.py b/tests/test_session.py
index 2538d3d2..99fd6c6d 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -42,9 +42,9 @@ class TestGetSessionObject(unittest.TestCase):
         model_id = DummyImageToImageModel.model_id
         success = sesh.load_model(model_id)
         self.assertTrue(success)
-        loaded_models = sesh.describe_models()
+        loaded_models = sesh.describe_loaded_models()
         self.assertTrue(model_id in loaded_models.keys())
         self.assertEqual(
-            loaded_models[model_id],
+            loaded_models[model_id]['class'],
             DummyImageToImageModel.__name__
         )
\ No newline at end of file
-- 
GitLab