Skip to content
Snippets Groups Projects
Commit e3162d6b authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Updated test for channel selection, and conforming image data to XYCZ order

parent 2095889d
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ def load_model(model_id: str, params: Dict[str, str] = None) -> dict:
session.load_model(model_id, params=params)
return session.describe_loaded_models()
@app.put('/i2i/infer/{model_id}') # image file in, image file out
@app.put('/i2i/infer/') # image file in, image file out
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
if model_id not in session.describe_loaded_models().keys():
raise HTTPException(
......
......@@ -8,7 +8,8 @@ from ilastik.workflows.objectClassification.objectClassificationWorkflow import
import numpy as np
from model_server.model import GenericImageDataAccessor, ImageToImageModel, ParameterExpectedError
from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.model import ImageToImageModel, ParameterExpectedError
class IlastikImageToImageModel(ImageToImageModel):
......@@ -64,11 +65,11 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel):
def infer(self, input_img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict):
dsi = [
{
'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data),
'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data[0:3]),
}
]
pxmap = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
return pxmap
pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
return InMemoryDataAccessor(data=pxmaps[0])
......@@ -14,7 +14,7 @@ class GenericImageDataAccessor(ABC):
def __init__(self):
"""
Abstract base class that exposes an interfaces for image data, irrespective of whether it is instantiated
from file I/O or other means.
from file I/O or other means. Enforces X, Y, C, Z dimensions in that order.
"""
pass
......@@ -22,23 +22,34 @@ class GenericImageDataAccessor(ABC):
def chroma(self):
return self.shape_dict['C']
@staticmethod
def conform_data(data):
if len(data.shape) > 4:
raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}')
ones = [1 for i in range(0, 4 - len(data.shape))]
return data.reshape(*data.shape, *ones)
def is_3d(self):
return True if self.shape_dict['Z'] > 1 else False
def get_one_channel_data (self, channel: int):
return InMemoryDataAccessor(self.data[:, :, channel, :])
return InMemoryDataAccessor(self.data[:, :, int(channel), :])
@property
def data(self): # XYCZ enforced
def data(self):
return self._data
@property
def shape(self):
return self._data.shape
@property
def shape_dict(self):
return dict(zip(('X', 'Y', 'C', 'Z'), self.data.shape))
class InMemoryDataAccessor(GenericImageDataAccessor):
def __init__(self, data):
self._data = data
self._data = self.conform_data(data)
class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file
def __init__(self, fpath: Path):
......@@ -67,7 +78,7 @@ class CziImageFileAccessor(GenericImageFileAccessor):
sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes}
if sd['S'] > 1 or sd['T'] > 1:
raise FileShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
idx = {k: sd[k] for k in ['X', 'Y', 'C', 'Z']}
xycz = np.moveaxis(
......@@ -76,10 +87,11 @@ class CziImageFileAccessor(GenericImageFileAccessor):
[0, 1, 2, 3]
)
try:
self._data = xycz.reshape(xycz.shape[0:4])
except Exception:
raise FileShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z')
# try:
# self._data = xycz.reshape(xycz.shape[0:4])
# except Exception:
# raise FileShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z')
self._data = self.conform_data(xycz.reshape(xycz.shape[0:4]))
def __del__(self):
self.czifile.close()
......@@ -108,7 +120,7 @@ class FileAccessorError(Error):
class FileNotFoundError(Error):
pass
class FileShapeError(Error):
class DataShapeError(Error):
pass
class FileWriteError(Error):
......
......@@ -4,7 +4,7 @@ import os
import numpy as np
from model_server.image import GenericImageDataAccessor
from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor
class Model(ABC):
......@@ -59,7 +59,7 @@ class ImageToImageModel(Model):
"""
@abstractmethod
def infer(self, img) -> (np.ndarray, dict):
def infer(self, img) -> (GenericImageDataAccessor, dict):
super().infer(img)
class DummyImageToImageModel(ImageToImageModel):
......@@ -69,14 +69,13 @@ class DummyImageToImageModel(ImageToImageModel):
def load(self):
return True
def infer(self, img: GenericImageDataAccessor) -> (np.ndarray, dict):
def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
super().infer(img)
w = img.shape_dict['X']
h = img.shape_dict['Y']
result = np.zeros([h, w], dtype='uint8')
result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255
return (result, {'success': True})
return InMemoryDataAccessor(data=result)
class Error(Exception):
pass
......
......@@ -28,7 +28,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
# run model inference
# TODO: call this async / await and report out infer status to optional callback
outdata, messages = model.infer(img)
outdata = model.infer(img)
dt_inf = time() - t0
# TODO: assert outdata format
......@@ -49,7 +49,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
model_id=model.model_id,
input_filepath=str(fpi),
output_filepath=str(outpath),
success=messages['success'],
success=True,
timer_results=timer_results
)
......
......@@ -62,8 +62,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
def test_i2i_inference_errors_model_not_found(self):
model_id = 'not_a_real_model'
resp = requests.put(
self.uri + f'i2i/infer/{model_id}',
params={'input_filename': 'not_a_real_file.name'}
self.uri + f'i2i/infer/',
params={
'model_id': model_id,
'input_filename': 'not_a_real_file.name'
}
)
print(resp.content)
self.assertEqual(resp.status_code, 409)
......@@ -77,7 +80,11 @@ class TestApiFromAutomatedClient(unittest.TestCase):
self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}')
self.copy_input_file_to_server()
resp_infer = requests.put(
self.uri + f'i2i/infer/{model.model_id}',
params={'input_filename': czifile['filename']},
self.uri + f'i2i/infer/',
params={
'model_id': model.model_id,
'input_filename': czifile['filename'],
'channel': 2,
},
)
self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}')
\ No newline at end of file
import unittest
from conf.testing import czifile, ilastik
from model_server.image import CziImageFileAccessor
from conf.testing import czifile, ilastik, output_path
from model_server.image import CziImageFileAccessor, write_accessor_data_to_file
from model_server.ilastik import IlastikPixelClassifierModel
class TestIlastikPixelClassification(unittest.TestCase):
......@@ -18,8 +18,21 @@ class TestIlastikPixelClassification(unittest.TestCase):
with self.assertRaises(io.UnsupportedOperation):
faulthandler.enable(file=sys.stdout)
def test_instantiate_pixel_classifier(self):
def test_run_pixel_classifier(self):
channel = 2
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
cf = CziImageFileAccessor(czifile['path'])
model.infer(cf.get_one_channel_data(2))
pxmap = model.infer(cf.get_one_channel_data(channel))
print(pxmap.shape_dict)
# self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
# self.assertEqual(pxmap.shape_dict['C'], 2)
# self.assertEqual(pxmap.shape_dict['Z'], 1)
#
# self.assertTrue(
# write_accessor_data_to_file(
# output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
# pxmap
# )
# )
\ No newline at end of file
import unittest
import numpy as np
from conf.testing import czifile, output_path
from model_server.image import CziImageFileAccessor, write_accessor_data_to_file
from model_server.image import CziImageFileAccessor, DataShapeError, InMemoryDataAccessor, write_accessor_data_to_file
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
......@@ -13,6 +16,8 @@ class TestCziImageFileAccess(unittest.TestCase):
self.assertEqual(cf.chroma, czifile['c'])
self.assertFalse(cf.is_3d())
self.assertEqual(len(cf.data.shape), 4)
self.assertEqual(cf.shape[0], czifile['h'])
self.assertEqual(cf.shape[1], czifile['w'])
def test_write_single_channel_tif(self):
ch = 2
......@@ -26,3 +31,16 @@ class TestCziImageFileAccess(unittest.TestCase):
)
self.assertEqual(cf.data.shape[0:2], mono.data.shape[0:2])
self.assertEqual(cf.data.shape[3], mono.data.shape[2])
def test_conform_data_shorter_than_xycz(self):
data = np.random.rand(256, 512)
acc = InMemoryDataAccessor(data)
self.assertEqual(
acc.shape_dict,
{'X': 256, 'Y': 512, 'C': 1, 'Z': 1}
)
def test_conform_data_longer_than_xycz(self):
data = np.random.rand(256, 512, 12, 8, 3)
with self.assertRaises(DataShapeError):
acc = InMemoryDataAccessor(data)
\ No newline at end of file
......@@ -25,25 +25,25 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
img, _ = model.infer(self.cf)
img = model.infer(self.cf)
w = czifile['w']
h = czifile['h']
self.assertEqual(
img.shape,
(h, w),
(h, w, 1, 1),
'Inferred image is not the expected shape'
)
self.assertEqual(
img[int(w/2), int(h/2)],
img.data[int(w/2), int(h/2)],
255,
'Middle pixel is not white as expected'
)
self.assertEqual(
img[0, 0],
img.data[0, 0],
0,
'First pixel is not black as expected'
)
......
......@@ -10,7 +10,7 @@ class TestGetSessionObject(unittest.TestCase):
self.model = DummyImageToImageModel()
def test_single_session_instance(self):
result = infer_image_to_image(czifile['path'], self.model, output_path)
result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2)
self.assertTrue(result.success)
import tifffile
......@@ -20,7 +20,7 @@ class TestGetSessionObject(unittest.TestCase):
self.assertEqual(
img.shape,
(h, w),
(h, w, 1, 1),
'Inferred image is not the expected shape'
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment