From 2291d080b6c269e953464d367b180ae4ae2f8856 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 31 Aug 2023 16:12:20 +0200
Subject: [PATCH] Inference API now works again with dummy model

---
 api.py                  | 26 ++++++++++++++++++++++----
 model_server/session.py | 13 +++++++++----
 tests/test_api.py       | 23 +++++++++--------------
 tests/test_model.py     |  7 +------
 4 files changed, 41 insertions(+), 28 deletions(-)

diff --git a/api.py b/api.py
index ad40f5ef..1c168b44 100644
--- a/api.py
+++ b/api.py
@@ -24,17 +24,35 @@ def list_active_models():
 
 @app.put('/models/dummy/load/')
 def load_dummy_model() -> dict:
-    return session.load_model(DummyImageToImageModel)
+    return {'model_id': session.load_model(DummyImageToImageModel)}
 
 @app.put('/models/ilastik/pixel_classification/load/')
 def load_ilastik_pixel_classification_model(params: str) -> dict:
-    return session.load_model(IlastikPixelClassifierModel, params)
+    return {'model_id': session.load_model(IlastikPixelClassifierModel, params)}
 
 @app.put('/models/ilastik/object_classification/load/')
 def load_ilastik_object_classification_model(params: str) -> dict:
-    return session.load_model(IlastikObjectClassifierModel, params)
+    return {'model_id': session.load_model(IlastikObjectClassifierModel, params)}
 
-@app.put('/i2i/infer/')
+# @app.put('/models/ilastik/pixel_classification/load/')
+# def infer_ilastik_pixel_classification_from_file(input_filename: str, channel: int = None) -> dict:
+#     inpath = session.inbound.path / input_filename
+#     if not inpath.exists():
+#         raise HTTPException(
+#             status_code=404,
+#             detail=f'Could not find file:\n{inpath}'
+#         )
+#
+#     record = infer_image_to_image(
+#         inpath,
+#         session.models[model_id]['object'],
+#         session.outbound.path,
+#         channel=channel,
+#     )
+#     session.record_workflow_run(record)
+#     return record
+
+@app.put('/infer/from_image_file')
 def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
     if model_id not in session.describe_loaded_models().keys():
         raise HTTPException(
diff --git a/model_server/session.py b/model_server/session.py
index 6db22003..1a9b38b9 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -74,15 +74,20 @@ class Session(object):
         mi = ModelClass(params=params)
         assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
         ii = 0
-        def mid(i): return f'{ModelClass.__name__}_{ii:02d}'
+
+        def mid(i):
+            return f'{ModelClass.__name__}_{i:02d}'
+
         while mid(ii) in self.models.keys():
             ii += 1
-        self.models[mid(ii)] = {
+
+        key = mid(ii)
+        self.models[key] = {
             'object': mi,
             'params': params
         }
-        self.log_event(f'Loaded model {mid}')
-        return self.describe_loaded_models()
+        self.log_event(f'Loaded model {key}')
+        return key
 
     def describe_loaded_models(self) -> dict:
         return {
diff --git a/tests/test_api.py b/tests/test_api.py
index d4f7d266..c61e4c2b 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -47,29 +47,24 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp.content, b'{}')
 
     def test_load_dummy_model(self):
-        model_key = DummyImageToImageModel.__name__ + '_00'
+        # model_key = DummyImageToImageModel.__name__ + '_00'
         resp_load = requests.put(
             self.uri + f'models/dummy/load',
         )
+        model_id = resp_load.json()['model_id']
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
         resp_list = requests.get(self.uri + 'models')
         self.assertEqual(resp_list.status_code, 200)
         rj = resp_list.json()
-        self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel')
+        self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
+        return model_id
 
-    def test_respond_with_error_when_invalid_model_loaded(self):
-        model_id = 'not_a_real_model'
-        resp = requests.put(
-            self.uri + f'models/load',
-            params={'model_id': model_id}
-        )
-        self.assertEqual(resp.status_code, 404)
-        print(resp.content)
 
     def test_respond_with_error_when_invalid_filepath_requested(self):
-        # model_id = self.test_load_dummy_model()
+        model_id = self.test_load_dummy_model()
+
         resp = requests.put(
-            self.uri + f'i2i/infer/',
+            self.uri + f'infer/from_image_file',
             params={
                 'model_id': model_id,
                 'input_filename': 'not_a_real_file.name'
@@ -81,7 +76,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
     def test_i2i_inference_errors_when_model_not_found(self):
         model_id = 'not_a_real_model'
         resp = requests.put(
-            self.uri + f'i2i/infer/',
+            self.uri + f'infer/from_image_file',
             params={
                 'model_id': model_id,
                 'input_filename': 'not_a_real_file.name'
@@ -93,7 +88,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         model_id = self.test_load_dummy_model()
         self.copy_input_file_to_server()
         resp_infer = requests.put(
-            self.uri + f'i2i/infer/',
+            self.uri + f'infer/from_image_file',
             params={
                 'model_id': model_id,
                 'input_filename': czifile['filename'],
diff --git a/tests/test_model.py b/tests/test_model.py
index 911b9f74..ca0324a9 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -46,9 +46,4 @@ class TestCziImageFileAccess(unittest.TestCase):
             img.data[0, 0],
             0,
             'First pixel is not black as expected'
-        )
-
-    def test_find_subclasses_recursively(self):
-        sc = DummyImageToImageModel
-        scs = Model.get_all_subclasses()
-        self.assertIn(DummyImageToImageModel, scs)
\ No newline at end of file
+        )
\ No newline at end of file
-- 
GitLab