Skip to content
Snippets Groups Projects
Commit 81e071ee authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Implemented autoload flag in model instantiation

parent 7f82c049
No related branches found
No related tags found
No related merge requests found
......@@ -18,10 +18,12 @@ class Model(ABC):
"""
self.autoload = autoload
self.params = params
self.loaded = False
if not autoload:
return None
if self.load():
self.loaded = True
else:
self.loaded = False
raise CouldNotLoadModelError()
return None
......
......@@ -73,7 +73,7 @@ class Session(object):
for mc in models:
if hasattr(mc, 'model_id') and getattr(mc, 'model_id') == model_id:
try:
mi = mc(params)
mi = mc(params=params)
assert mi.loaded
except:
raise CouldNotInstantiateModelError()
......
......@@ -93,5 +93,5 @@ class TestApiFromAutomatedClient(unittest.TestCase):
'channel': 2,
},
)
self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model_id}', resp_infer.content.decode())
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
......@@ -21,6 +21,17 @@ class TestIlastikPixelClassification(unittest.TestCase):
with self.assertRaises(io.UnsupportedOperation):
faulthandler.enable(file=sys.stdout)
def test_raise_error_if_autoload_disabled(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}, autoload=False)
w = 512
h = 256
input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
with self.assertRaises(AttributeError):
pxmap = model.infer(input_img)
def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
w = 512
......@@ -47,7 +58,6 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
self.assertEqual(pxmap.shape_dict['C'], 2)
self.assertEqual(pxmap.shape_dict['Z'], 1)
print(pxmap.shape_dict)
self.assertTrue(
write_accessor_data_to_file(
output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
......
......@@ -8,7 +8,7 @@ class TestCziImageFileAccess(unittest.TestCase):
self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self):
model = DummyImageToImageModel()
model = DummyImageToImageModel(params=None)
self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self):
......
......@@ -20,7 +20,7 @@ class TestGetSessionObject(unittest.TestCase):
self.assertEqual(
img.shape,
(h, w, 1, 1),
(h, w),
'Inferred image is not the expected shape'
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment