From 81e071eed20e40380af3dae35406e6b4826ad662 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Aug 2023 11:19:50 +0200
Subject: [PATCH] Implemented autoload flag in model instantiation

---
 model_server/model.py   |  4 +++-
 model_server/session.py |  2 +-
 tests/test_api.py       |  2 +-
 tests/test_ilastik.py   | 12 +++++++++++-
 tests/test_model.py     |  2 +-
 tests/test_workflow.py  |  2 +-
 6 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/model_server/model.py b/model_server/model.py
index 8c07d2d8..38b772ff 100644
--- a/model_server/model.py
+++ b/model_server/model.py
@@ -18,10 +18,12 @@ class Model(ABC):
         """
         self.autoload = autoload
         self.params = params
+        self.loaded = False
+        if not autoload:
+            return None
         if self.load():
             self.loaded = True
         else:
-            self.loaded = False
             raise CouldNotLoadModelError()
         return None
 
diff --git a/model_server/session.py b/model_server/session.py
index 8831b606..bd9a5d9d 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -73,7 +73,7 @@ class Session(object):
         for mc in models:
             if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
                 try:
-                    mi = mc(params)
+                    mi = mc(params=params)
                     assert mi.loaded
                 except:
                     raise CouldNotInstantiateModelError()
diff --git a/tests/test_api.py b/tests/test_api.py
index f2805eef..0b2e844b 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -93,5 +93,5 @@ class TestApiFromAutomatedClient(unittest.TestCase):
                 'channel': 2,
             },
         )
-        self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model_id}', resp_infer.content.decode())
+        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
 
diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py
index ac387a41..9a596bce 100644
--- a/tests/test_ilastik.py
+++ b/tests/test_ilastik.py
@@ -21,6 +21,17 @@ class TestIlastikPixelClassification(unittest.TestCase):
         with self.assertRaises(io.UnsupportedOperation):
             faulthandler.enable(file=sys.stdout)
 
+
+    def test_raise_error_if_autoload_disabled(self):
+        model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}, autoload=False)
+        w = 512
+        h = 256
+
+        input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
+
+        with self.assertRaises(AttributeError):
+            pxmap = model.infer(input_img)
+
     def test_run_pixel_classifier_on_random_data(self):
         model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
         w = 512
@@ -47,7 +58,6 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
         self.assertEqual(pxmap.shape_dict['C'], 2)
         self.assertEqual(pxmap.shape_dict['Z'], 1)
-        print(pxmap.shape_dict)
         self.assertTrue(
             write_accessor_data_to_file(
                 output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
diff --git a/tests/test_model.py b/tests/test_model.py
index 9f4ab388..b2654af6 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -8,7 +8,7 @@ class TestCziImageFileAccess(unittest.TestCase):
         self.cf = CziImageFileAccessor(czifile['path'])
 
     def test_instantiate_model(self):
-        model = DummyImageToImageModel()
+        model = DummyImageToImageModel(params=None)
         self.assertTrue(model.loaded)
 
     def test_instantiate_model_with_nondefault_kwarg(self):
diff --git a/tests/test_workflow.py b/tests/test_workflow.py
index 83935d3d..6bc8d56d 100644
--- a/tests/test_workflow.py
+++ b/tests/test_workflow.py
@@ -20,7 +20,7 @@ class TestGetSessionObject(unittest.TestCase):
 
         self.assertEqual(
             img.shape,
-            (h, w, 1, 1),
+            (h, w),
             'Inferred image is not the expected shape'
         )
 
-- 
GitLab