From a29dc6bb0d716281a52313fd90ee325817a4e198 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 24 Aug 2023 14:48:59 +0200
Subject: [PATCH] Implemented model-loading logic at session level

---
 api.py                   | 13 ++++++++++---
 model_server/model.py    | 16 +++++++++++++++-
 model_server/session.py  | 17 ++++++++++++++++-
 model_server/workflow.py |  1 +
 tests/test_api.py        | 26 +++++++++++++++++++++++++-
 tests/test_model.py      | 16 +++++++++++++++-
 6 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/api.py b/api.py
index 51208362..0640ad67 100644
--- a/api.py
+++ b/api.py
@@ -20,13 +20,14 @@ def read_root():
 def list_active_models():
     return session.models # TODO: include model type too
 
-@app.get('/models/ilastik/load/')
+@app.get('/models/{model_id}/load/')
 def load_model(model_id: str, project_file: Path) -> Path: # does API autoencode path as JSON?
     if model_id in session.models.keys():
         raise HTTPException(
             status_code=409,
             detail=f'Model with id {model_id} has already been loaded'
         )
+    session
 
 @app.post('/i2i/infer/{model_id}') # image file in, image file out
 def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
@@ -41,7 +42,13 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
         session.inbound / imgf,
         session.models[model_id],
         session.outbound,
-        channel=channel
+        channel=channel,
+        # TODO: optional callback for status reporting
     )
     session.record_workflow_run(record)
-    return record
\ No newline at end of file
+    return record
+
+# TODO: report out model inference status
+@app.get('/i2i/status/{model_id}')
+def status_model_inference(model_id: str) -> dict:
+    pass
\ No newline at end of file
diff --git a/model_server/model.py b/model_server/model.py
index c9c3140b..6eafdf3d 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -5,6 +5,7 @@ import numpy as np
 
 from model_server.image import GenericImageFileAccessor
 
+
 class Model(ABC):
 
     def __init__(self, autoload=True):
@@ -14,9 +15,18 @@ class Model(ABC):
         :param autoload: automatically load model and dependencies into memory if True
         """
         self.autoload = autoload
+        if self.load():
+            self.loaded = True
+        else:
+            self.loaded = False
+            raise CouldNotLoadModelError()
 
     @abstractmethod
     def load(self):
+        """
+        Abstract method that carries out the expectedly time-consuming step of loading a model into memory
+        :return: True if successful, else False
+        """
         pass
 
     @abstractmethod
@@ -55,7 +65,7 @@ class DummyImageToImageModel(Model):
     model_id = 'dummy_make_white_square'
 
     def load(self):
-        self.loaded = True
+        return True
 
     def infer(self, img: GenericImageFileAccessor, channel=None) -> (np.ndarray, dict):
         super().infer(img, channel)
@@ -65,8 +75,12 @@ class DummyImageToImageModel(Model):
         result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
         return (result, {'success': True})
 
+
 class Error(Exception):
     pass
 
 class ChannelTooHighError(Error):
+    pass
+
+class CouldNotLoadModelError(Error):
     pass
\ No newline at end of file
diff --git a/model_server/session.py b/model_server/session.py
index 21e66332..17958bb9 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -5,6 +5,7 @@ from pathlib import Path
 from time import strftime, localtime
 
 from conf.server import paths
+from model_server.model import Model
 from model_server.share import SharedImageDirectory
 from model_server.workflow import WorkflowRunRecord
 
@@ -34,7 +35,6 @@ class Session(object):
         self.manifest_json = self.where_records / f'{self.session_id}-manifest.json'
         open(self.manifest_json, 'w').close() # instantiate empty json file
 
-
     @staticmethod
     def create_session_id(look_where: Path) -> str:
         """
@@ -61,6 +61,21 @@ class Session(object):
         with open(self.manifest_json, 'w+') as fh:
             json.dump(record.dict(), fh)
 
+    def load_model(self, model_id: str) -> bool:
+        """
+        Load an instance of first model class that matches model_id string
+        :param model_id:
+        :return: True if model successfully loaded, False if not
+        """
+        for mc in Model.__subclasses__():
+            if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
+                mi = mc()
+                assert mi.loaded
+                self.models.append(mi)
+                return True
+        return False
+
+
     def restart(self):
         self.__init__()
 
diff --git a/model_server/workflow.py b/model_server/workflow.py
index 5f1a4a92..888b3374 100644
--- a/model_server/workflow.py
+++ b/model_server/workflow.py
@@ -30,6 +30,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
     assert (img.shape_dict['T'] == 1)
 
     # run model inference
+    # TODO: call this async / await and report out infer status to optional callback
     ch = kwargs.get('channel')
     outdata, messages = model.infer(img, channel=ch)
     dt_inf = time() - t0
diff --git a/tests/test_api.py b/tests/test_api.py
index 2bdd2451..f7c778c0 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,7 +1,9 @@
 from multiprocessing import Process
 import requests
 import unittest
-import uvicorn
+
+from conf.testing import czifile, output_path
+from model_server.model import DummyImageToImageModel
 
 class TestApiFromAutomatedClient(unittest.TestCase):
     def setUp(self) -> None:
@@ -25,3 +27,25 @@ class TestApiFromAutomatedClient(unittest.TestCase):
         resp = requests.get(self.uri, )
         self.assertEqual(resp.status_code, 200)
 
+    def test_list_empty_loaded_models(self):
+        resp = requests.get(self.uri + 'models')
+        print(resp.content)
+        self.assertEqual(resp.status_code, 200)
+
+    def test_load_model(self):
+        resp = requests.get(self.uri + 'load_')
+
+    def test_i2i_inference_errors_model_not_sound(self):
+        model_id = 'not_a_real_model'
+        resp = requests.post(self.uri + f'i2i/infer/{model_id}')
+        self.assertEqual(resp.status_code, 404)
+
+    def test_i2i_dummy_inference_by_api(self):
+        model = DummyImageToImageModel()
+        model_id = model.model_id
+        resp = requests.post(
+            self.uri + f'/i2i/infer/{model_id}',
+            str(czifile['path']),
+        )
+        print(resp)
+        self.assertEqual(resp.status_code, 200)
\ No newline at end of file
diff --git a/tests/test_model.py b/tests/test_model.py
index 1920ea3b..aff9db86 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,16 +1,30 @@
 import unittest
 from conf.testing import czifile
 from model_server.image import CziImageFileAccessor
-from model_server.model import DummyImageToImageModel
+from model_server.model import DummyImageToImageModel, CouldNotLoadModelError
 
 class TestCziImageFileAccess(unittest.TestCase):
     def setUp(self) -> None:
         self.cf = CziImageFileAccessor(czifile['path'])
 
+    def test_instantiate_model(self):
+        model = DummyImageToImageModel()
+        self.assertTrue(model.loaded)
+
     def test_instantiate_model_with_nondefault_kwarg(self):
         model = DummyImageToImageModel(autoload=False)
         self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
 
+    def test_raise_error_if_cannot_load_model(self):
+        class UnloadableDummyImageToImageModel(DummyImageToImageModel):
+            def load(self):
+                return False
+
+        self.assertRaises(
+            CouldNotLoadModelError,
+            mi=UnloadableDummyImageToImageModel,
+        )
+
     def test_czifile_is_correct_shape(self):
         model = DummyImageToImageModel()
         img, _ = model.infer(self.cf, channel=1)
-- 
GitLab