Skip to content
Snippets Groups Projects
Commit 3eb38020 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Harmonized return value of model inference method

parent 81e071ee
No related branches found
No related tags found
No related merge requests found
......@@ -71,7 +71,7 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
[2, 1, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=xycz)
return InMemoryDataAccessor(data=xycz), {'success': True}
......@@ -47,9 +47,8 @@ class Model(ABC):
pass
@abstractmethod
def infer(self, img: GenericImageDataAccessor) -> (np.ndarray, dict): # return json describing inference result
if self.autoload:
self.load()
def infer(self, img: GenericImageDataAccessor) -> (object, dict): # return json describing inference result
pass
def reload(self):
self.load()
......@@ -61,8 +60,8 @@ class ImageToImageModel(Model):
"""
@abstractmethod
def infer(self, img) -> (GenericImageDataAccessor, dict):
super().infer(img)
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
pass
class DummyImageToImageModel(ImageToImageModel):
......@@ -71,13 +70,13 @@ class DummyImageToImageModel(ImageToImageModel):
def load(self):
return True
def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
super().infer(img)
w = img.shape_dict['X']
h = img.shape_dict['Y']
result = np.zeros([h, w], dtype='uint8')
result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
return InMemoryDataAccessor(data=result)
return InMemoryDataAccessor(data=result), {'success': True}
class Error(Exception):
pass
......
......@@ -28,7 +28,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
# run model inference
# TODO: call this async / await and report out infer status to optional callback
outdata = model.infer(img)
outdata, _ = model.infer(img)
dt_inf = time() - t0
# TODO: assert outdata format
......
......@@ -30,7 +30,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
with self.assertRaises(AttributeError):
pxmap = model.infer(input_img)
pxmap , _= model.infer(input_img)
def test_run_pixel_classifier_on_random_data(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
......@@ -39,7 +39,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1))
pxmap = model.infer(input_img)
pxmap, _ = model.infer(input_img)
self.assertEqual(pxmap.shape, (w, h, 2, 1))
def test_run_pixel_classifier(self):
......@@ -53,7 +53,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.assertEqual(mono_image.shape_dict['C'], 1)
self.assertEqual(mono_image.shape_dict['Z'], 1)
pxmap = model.infer(mono_image)
pxmap, _ = model.infer(mono_image)
self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
self.assertEqual(pxmap.shape_dict['C'], 2)
......
......@@ -25,7 +25,7 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
img = model.infer(self.cf)
img, _ = model.infer(self.cf)
w = czifile['w']
h = czifile['h']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment