diff --git a/api.py b/api.py index 89e42d8fde48a22e8447964da59e32fe388fa5ee..1bf0f17d31e93c1a7c78d735f021b611532466b5 100644 --- a/api.py +++ b/api.py @@ -30,7 +30,7 @@ def load_model(model_id: str, params: Dict[str, str] = None) -> dict: session.load_model(model_id, params=params) return session.describe_loaded_models() -@app.put('/i2i/infer/{model_id}') # image file in, image file out +@app.put('/i2i/infer/') # image file in, image file out def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: if model_id not in session.describe_loaded_models().keys(): raise HTTPException( diff --git a/model_server/ilastik.py b/model_server/ilastik.py index a345925ff2bea707a283c0373621712a4172475c..7be2269c14123cb8fcaa5747316c32516d8f44a9 100644 --- a/model_server/ilastik.py +++ b/model_server/ilastik.py @@ -8,7 +8,8 @@ from ilastik.workflows.objectClassification.objectClassificationWorkflow import import numpy as np -from model_server.model import GenericImageDataAccessor, ImageToImageModel, ParameterExpectedError +from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.model import ImageToImageModel, ParameterExpectedError class IlastikImageToImageModel(ImageToImageModel): @@ -64,11 +65,11 @@ class IlastikPixelClassifierModel(IlastikImageToImageModel): def infer(self, input_img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict): dsi = [ { - 'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data), + 'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data[0:3]), } ] - pxmap = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) - return pxmap + pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) + return InMemoryDataAccessor(data=pxmaps[0]) diff --git a/model_server/image.py b/model_server/image.py index 2e314d1b42a183c1099c9685a915c1791de987be..4bea66a1ed2660d0a8384eff681f495a5c3e9846 100644 --- a/model_server/image.py +++ b/model_server/image.py @@ -14,7 +14,7 @@ class GenericImageDataAccessor(ABC): def __init__(self): """ Abstract base class that exposes an interfaces for image data, irrespective of whether it is instantiated - from file I/O or other means. + from file I/O or other means. Enforces X, Y, C, Z dimensions in that order. """ pass @@ -22,23 +22,34 @@ class GenericImageDataAccessor(ABC): def chroma(self): return self.shape_dict['C'] + @staticmethod + def conform_data(data): + if len(data.shape) > 4: + raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}') + ones = [1 for i in range(0, 4 - len(data.shape))] + return data.reshape(*data.shape, *ones) + def is_3d(self): return True if self.shape_dict['Z'] > 1 else False def get_one_channel_data (self, channel: int): - return InMemoryDataAccessor(self.data[:, :, channel, :]) + return InMemoryDataAccessor(self.data[:, :, int(channel), :]) @property - def data(self): # XYCZ enforced + def data(self): return self._data + @property + def shape(self): + return self._data.shape + @property def shape_dict(self): return dict(zip(('X', 'Y', 'C', 'Z'), self.data.shape)) class InMemoryDataAccessor(GenericImageDataAccessor): def __init__(self, data): - self._data = data + self._data = self.conform_data(data) class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file def __init__(self, fpath: Path): @@ -67,7 +78,7 @@ class CziImageFileAccessor(GenericImageFileAccessor): sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes} if sd['S'] > 1 or sd['T'] > 1: - raise FileShapeError(f'Cannot handle image with multiple positions or time points: {sd}') + raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}') idx = {k: sd[k] for k in ['X', 'Y', 'C', 'Z']} xycz = np.moveaxis( @@ -76,10 +87,11 @@ class CziImageFileAccessor(GenericImageFileAccessor): [0, 1, 2, 3] ) - try: - self._data = xycz.reshape(xycz.shape[0:4]) - except Exception: - raise FileShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z') + # try: + # self._data = xycz.reshape(xycz.shape[0:4]) + # except Exception: + # raise FileShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z') + self._data = self.conform_data(xycz.reshape(xycz.shape[0:4])) def __del__(self): self.czifile.close() @@ -108,7 +120,7 @@ class FileAccessorError(Error): class FileNotFoundError(Error): pass -class FileShapeError(Error): +class DataShapeError(Error): pass class FileWriteError(Error): diff --git a/model_server/model.py b/model_server/model.py index 672bb63d6a2e7cc5d52d667576bbbc42e41bea2c..8c07d2d8c0d71dd3caa724905d2ea40566b865d8 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -4,7 +4,7 @@ import os import numpy as np -from model_server.image import GenericImageDataAccessor +from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor class Model(ABC): @@ -59,7 +59,7 @@ class ImageToImageModel(Model): """ @abstractmethod - def infer(self, img) -> (np.ndarray, dict): + def infer(self, img) -> (GenericImageDataAccessor, dict): super().infer(img) class DummyImageToImageModel(ImageToImageModel): @@ -69,14 +69,13 @@ class DummyImageToImageModel(ImageToImageModel): def load(self): return True - def infer(self, img: GenericImageDataAccessor) -> (np.ndarray, dict): + def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor: super().infer(img) w = img.shape_dict['X'] h = img.shape_dict['Y'] result = np.zeros([h, w], dtype='uint8') result[floor(0.25 * h) : floor(0.75 * h), floor(0.25 * w) : floor(0.75 * w)] = 255 - return (result, {'success': True}) - + return InMemoryDataAccessor(data=result) class Error(Exception): pass diff --git a/model_server/workflow.py b/model_server/workflow.py index ece9b1b443db2d5f2e6449bf441cf47e006a8151..38803ed6ea391240ae5641423e55e25d311dc07d 100644 --- a/model_server/workflow.py +++ b/model_server/workflow.py @@ -28,7 +28,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict: # run model inference # TODO: call this async / await and report out infer status to optional callback - outdata, messages = model.infer(img) + outdata = model.infer(img) dt_inf = time() - t0 # TODO: assert outdata format @@ -49,7 +49,7 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict: model_id=model.model_id, input_filepath=str(fpi), output_filepath=str(outpath), - success=messages['success'], + success=True, timer_results=timer_results ) diff --git a/tests/test_api.py b/tests/test_api.py index e290d16350f6e89495cc36bd63cbe949610006a3..f5869bddc0e273481c604847ee55029ba7e591ce 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -62,8 +62,11 @@ class TestApiFromAutomatedClient(unittest.TestCase): def test_i2i_inference_errors_model_not_found(self): model_id = 'not_a_real_model' resp = requests.put( - self.uri + f'i2i/infer/{model_id}', - params={'input_filename': 'not_a_real_file.name'} + self.uri + f'i2i/infer/', + params={ + 'model_id': model_id, + 'input_filename': 'not_a_real_file.name' + } ) print(resp.content) self.assertEqual(resp.status_code, 409) @@ -77,7 +80,11 @@ class TestApiFromAutomatedClient(unittest.TestCase): self.assertEqual(resp_load.status_code, 200, f'Error loading {model.model_id}') self.copy_input_file_to_server() resp_infer = requests.put( - self.uri + f'i2i/infer/{model.model_id}', - params={'input_filename': czifile['filename']}, + self.uri + f'i2i/infer/', + params={ + 'model_id': model.model_id, + 'input_filename': czifile['filename'], + 'channel': 2, + }, ) self.assertEqual(resp_infer.status_code, 200, f'Error inferring from {model.model_id}') \ No newline at end of file diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 6fd345634b9e3fa61141a436fb8dd3b088243f07..94a7cbbb3da47de8643c06fad22752a83e45d867 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -1,6 +1,6 @@ import unittest -from conf.testing import czifile, ilastik -from model_server.image import CziImageFileAccessor +from conf.testing import czifile, ilastik, output_path +from model_server.image import CziImageFileAccessor, write_accessor_data_to_file from model_server.ilastik import IlastikPixelClassifierModel class TestIlastikPixelClassification(unittest.TestCase): @@ -18,8 +18,21 @@ class TestIlastikPixelClassification(unittest.TestCase): with self.assertRaises(io.UnsupportedOperation): faulthandler.enable(file=sys.stdout) - def test_instantiate_pixel_classifier(self): + def test_run_pixel_classifier(self): + channel = 2 model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) cf = CziImageFileAccessor(czifile['path']) - model.infer(cf.get_one_channel_data(2)) + pxmap = model.infer(cf.get_one_channel_data(channel)) + print(pxmap.shape_dict) + + # self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) + # self.assertEqual(pxmap.shape_dict['C'], 2) + # self.assertEqual(pxmap.shape_dict['Z'], 1) + # + # self.assertTrue( + # write_accessor_data_to_file( + # output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', + # pxmap + # ) + # ) \ No newline at end of file diff --git a/tests/test_image.py b/tests/test_image.py index 8aef0eaf392c5dfa946df12910749f6141bff577..3ca3cb2e2c8d86f16e71bc5c7d29f3a0ad3fe7f8 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -1,6 +1,9 @@ import unittest + +import numpy as np + from conf.testing import czifile, output_path -from model_server.image import CziImageFileAccessor, write_accessor_data_to_file +from model_server.image import CziImageFileAccessor, DataShapeError, InMemoryDataAccessor, write_accessor_data_to_file class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -13,6 +16,8 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertEqual(cf.chroma, czifile['c']) self.assertFalse(cf.is_3d()) self.assertEqual(len(cf.data.shape), 4) + self.assertEqual(cf.shape[0], czifile['h']) + self.assertEqual(cf.shape[1], czifile['w']) def test_write_single_channel_tif(self): ch = 2 @@ -26,3 +31,16 @@ class TestCziImageFileAccess(unittest.TestCase): ) self.assertEqual(cf.data.shape[0:2], mono.data.shape[0:2]) self.assertEqual(cf.data.shape[3], mono.data.shape[2]) + + def test_conform_data_shorter_than_xycz(self): + data = np.random.rand(256, 512) + acc = InMemoryDataAccessor(data) + self.assertEqual( + acc.shape_dict, + {'X': 256, 'Y': 512, 'C': 1, 'Z': 1} + ) + + def test_conform_data_longer_than_xycz(self): + data = np.random.rand(256, 512, 12, 8, 3) + with self.assertRaises(DataShapeError): + acc = InMemoryDataAccessor(data) \ No newline at end of file diff --git a/tests/test_model.py b/tests/test_model.py index 2c1d7c647605e7d75f7702cd3badcc6aa12986cc..9f4ab388190a651536ef29b86fc082f2978f9584 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -25,25 +25,25 @@ class TestCziImageFileAccess(unittest.TestCase): def test_czifile_is_correct_shape(self): model = DummyImageToImageModel() - img, _ = model.infer(self.cf) + img = model.infer(self.cf) w = czifile['w'] h = czifile['h'] self.assertEqual( img.shape, - (h, w), + (h, w, 1, 1), 'Inferred image is not the expected shape' ) self.assertEqual( - img[int(w/2), int(h/2)], + img.data[int(w/2), int(h/2)], 255, 'Middle pixel is not white as expected' ) self.assertEqual( - img[0, 0], + img.data[0, 0], 0, 'First pixel is not black as expected' ) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index fff03bbc2d83a776cb7341d6a049b89757d2cb32..83935d3df306542cfbe0ea6f95efce591f28d5ea 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -10,7 +10,7 @@ class TestGetSessionObject(unittest.TestCase): self.model = DummyImageToImageModel() def test_single_session_instance(self): - result = infer_image_to_image(czifile['path'], self.model, output_path) + result = infer_image_to_image(czifile['path'], self.model, output_path, channel=2) self.assertTrue(result.success) import tifffile @@ -20,7 +20,7 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual( img.shape, - (h, w), + (h, w, 1, 1), 'Inferred image is not the expected shape' )