From e2e7ea215281e9dfd9f245671f715f60daac2037 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 28 Aug 2023 14:23:04 +0200
Subject: [PATCH] Successfully tested calling inference over API

---
 api.py                  | 17 +++++++++++------
 conf/testing.py         |  5 +++--
 model_server/image.py   | 10 +++++-----
 model_server/session.py |  2 +-
 model_server/share.py   |  3 ++-
 tests/test_api.py       | 41 ++++++++++++++++++++++++++++++-----------
 tests/test_session.py   | 13 +++++++++----
 7 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/api.py b/api.py
index 5a5f25aa..f5d7976f 100644
--- a/api.py
+++ b/api.py
@@ -30,19 +30,24 @@ def load_model(model_id: str) -> dict:
     session.load_model(model_id)
     return session.describe_models()
 
-@app.post('/i2i/infer/{model_id}') # image file in, image file out
-def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
-    if model_id not in session.models.keys():
+@app.put('/i2i/infer/{model_id}') # image file in, image file out
+def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
+    if model_id not in session.describe_models().keys():
         raise HTTPException(
             status_code=409,
             detail=f'Model {model_id} has not been loaded'
         )
+    inpath = session.inbound.path / input_filename
+    if not inpath.exists():
+        raise HTTPException(
+            status_code=404,
+            detail=f'Could not find file:\n{inpath}'
+        )
 
-    # TODO: try block workflow, catch and redirect HTTP
     record = infer_image_to_image(
-        session.inbound / imgf,
+        inpath,
         session.models[model_id],
-        session.outbound,
+        session.outbound.path,
         channel=channel,
         # TODO: optional callback for status reporting
     )
diff --git a/conf/testing.py b/conf/testing.py
index 753153a3..94a42aa8 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -1,9 +1,10 @@
 from pathlib import Path
 
 root = Path('c:/Users/rhodes/projects/proj0015-model-server/resources')
-
+filename = 'Selection--W0000--P0001-T0001.czi'
 czifile = {
-    'path': root / 'testdata' / 'Selection--W0000--P0001-T0001.czi',
+    'filename': filename,
+    'path': root / 'testdata' / filename,
     'w': 1024,
     'h': 1024,
     'c': 4,
diff --git a/model_server/image.py b/model_server/image.py
index ec9dd677..bd6b1abc 100644
--- a/model_server/image.py
+++ b/model_server/image.py
@@ -54,17 +54,17 @@ def generate_file_accessor(fpath):
     else:
         raise FileAccessorError(f'Could not match a file accessor with {fpath}')
 
-def Error(Exception):
+class Error(Exception):
     pass
 
-def FileAccessorError(Error):
+class FileAccessorError(Error):
     pass
 
-def FileNotFoundError(Error):
+class FileNotFoundError(Error):
     pass
 
-def FileShapeError(Error):
+class FileShapeError(Error):
     pass
 
-def FileWriteError(Error):
+class FileWriteError(Error):
     pass
\ No newline at end of file
diff --git a/model_server/session.py b/model_server/session.py
index 66a3ac11..f3b65d62 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -72,7 +72,7 @@ class Session(object):
             if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
                 mi = mc()
                 assert mi.loaded
-                self.models['model_id'] = mi
+                self.models[model_id] = mi
                 return True
         raise CouldNotFindModelError(
             f'Could not find {model_id} in:\n{models}',
diff --git a/model_server/share.py b/model_server/share.py
index e817febc..e270dea6 100644
--- a/model_server/share.py
+++ b/model_server/share.py
@@ -1,11 +1,12 @@
 import os
+import pathlib
 
 def validate_directory_exists(path):
     return os.path.exists(path)
 
 class SharedImageDirectory(object):
 
-    def __init__(self, path):
+    def __init__(self, path: pathlib.Path):
         self.path = path
         self.is_memory_mapped = False
 
diff --git a/tests/test_api.py b/tests/test_api.py
index 17afe2e3..ecc0296c 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,8 +1,9 @@
 from multiprocessing import Process
+
 import requests
 import unittest
 
-from conf.testing import czifile, output_path
+from conf.testing import czifile
 from model_server.model import DummyImageToImageModel
 
 class TestApiFromAutomatedClient(unittest.TestCase):
@@ -20,6 +21,19 @@ class TestApiFromAutomatedClient(unittest.TestCase):
         self.uri = f'http://{host}:{port}/'
         self.server_process.start()
 
+    @staticmethod
+    def copy_input_file_to_server():
+        import pathlib
+        from shutil import copyfile
+        from conf.server import paths
+
+        outpath = pathlib.Path(paths['images']['inbound'] / czifile['filename'])
+
+        copyfile(
+            czifile['path'],
+            outpath
+        )
+
     def tearDown(self) -> None:
         self.server_process.terminate()
 
@@ -38,19 +52,24 @@ class TestApiFromAutomatedClient(unittest.TestCase):
         self.assertEqual(resp_load.status_code, 200)
         resp_list = requests.get(self.uri + 'models')
         self.assertEqual(resp_list.status_code, 200)
-        self.assertEqual(resp_list.content, b'{"model_id":"DummyImageToImageModel"}')
+        self.assertEqual(resp_list.content, b'{"dummy_make_white_square":"DummyImageToImageModel"}')
 
-    def test_i2i_inference_errors_model_not_sound(self):
+    def test_i2i_inference_errors_model_not_found(self):
         model_id = 'not_a_real_model'
-        resp = requests.post(self.uri + f'i2i/infer/{model_id}')
-        self.assertEqual(resp.status_code, 404)
+        resp = requests.put(
+            self.uri + f'i2i/infer/{model_id}',
+            params={'input_filename': 'not_a_real_file.name'}
+        )
+        print(resp.content)
+        self.assertEqual(resp.status_code, 409)
 
     def test_i2i_dummy_inference_by_api(self):
         model = DummyImageToImageModel()
-        model_id = model.model_id
-        resp = requests.post(
-            self.uri + f'/i2i/infer/{model_id}',
-            str(czifile['path']),
+        resp_load = requests.get(self.uri + f'models/{model.model_id}/load')
+        self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
+        self.copy_input_file_to_server()
+        resp_infer = requests.put(
+            self.uri + f'i2i/infer/{model.model_id}',
+            params={'input_filename': czifile['filename']},
         )
-        print(resp)
-        self.assertEqual(resp.status_code, 200)
\ No newline at end of file
+        self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}')
\ No newline at end of file
diff --git a/tests/test_session.py b/tests/test_session.py
index 5d10b07c..2538d3d2 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -37,9 +37,14 @@ class TestGetSessionObject(unittest.TestCase):
             do = json.load(fh)
         self.assertEqual(di.dict(), do, 'Manifest record is not correct')
 
-    def test_session_load_model(self):
+    def test_session_loads_model(self):
         sesh = Session()
-        success = sesh.load_model(DummyImageToImageModel.model_id)
+        model_id = DummyImageToImageModel.model_id
+        success = sesh.load_model(model_id)
         self.assertTrue(success)
-        self.assertTrue('model_id' in sesh.models.keys())
-        self.assertEqual(sesh.models['model_id'].__class__, DummyImageToImageModel)
\ No newline at end of file
+        loaded_models = sesh.describe_models()
+        self.assertTrue(model_id in loaded_models.keys())
+        self.assertEqual(
+            loaded_models[model_id],
+            DummyImageToImageModel.__name__
+        )
\ No newline at end of file
-- 
GitLab