From 81e071eed20e40380af3dae35406e6b4826ad662 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 30 Aug 2023 11:19:50 +0200 Subject: [PATCH] Implemented autoload flag in model instantiation --- model_server/model.py | 4 +++- model_server/session.py | 2 +- tests/test_api.py | 2 +- tests/test_ilastik.py | 12 +++++++++++- tests/test_model.py | 2 +- tests/test_workflow.py | 2 +- 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/model_server/model.py b/model_server/model.py index 8c07d2d8..38b772ff 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -18,10 +18,12 @@ class Model(ABC): """ self.autoload = autoload self.params = params + self.loaded = False + if not autoload: + return None if self.load(): self.loaded = True else: - self.loaded = False raise CouldNotLoadModelError() return None diff --git a/model_server/session.py b/model_server/session.py index 8831b606..bd9a5d9d 100644 --- a/model_server/session.py +++ b/model_server/session.py @@ -73,7 +73,7 @@ class Session(object): for mc in models: if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id: try: - mi = mc(params) + mi = mc(params=params) assert mi.loaded except: raise CouldNotInstantiateModelError() diff --git a/tests/test_api.py b/tests/test_api.py index f2805eef..0b2e844b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -93,5 +93,5 @@ class TestApiFromAutomatedClient(unittest.TestCase): 'channel': 2, }, ) - self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model_id}', resp_infer.content.decode()) + self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index ac387a41..9a596bce 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -21,6 +21,17 @@ class TestIlastikPixelClassification(unittest.TestCase): with self.assertRaises(io.UnsupportedOperation): faulthandler.enable(file=sys.stdout) + + def test_raise_error_if_autoload_disabled(self): + model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}, autoload=False) + w = 512 + h = 256 + + input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1)) + + with self.assertRaises(AttributeError): + pxmap = model.infer(input_img) + def test_run_pixel_classifier_on_random_data(self): model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) w = 512 @@ -47,7 +58,6 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) self.assertEqual(pxmap.shape_dict['C'], 2) self.assertEqual(pxmap.shape_dict['Z'], 1) - print(pxmap.shape_dict) self.assertTrue( write_accessor_data_to_file( output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', diff --git a/tests/test_model.py b/tests/test_model.py index 9f4ab388..b2654af6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -8,7 +8,7 @@ class TestCziImageFileAccess(unittest.TestCase): self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummyImageToImageModel() + model = DummyImageToImageModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 83935d3d..6bc8d56d 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -20,7 +20,7 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual( img.shape, - (h, w, 1, 1), + (h, w), 'Inferred image is not the expected shape' ) -- GitLab