diff --git a/model_server/ilastik.py b/model_server/ilastik.py index 7be2269c14123cb8fcaa5747316c32516d8f44a9..aadfad6acbc63f411e549475cef286a91fe823ef 100644 --- a/model_server/ilastik.py +++ b/model_server/ilastik.py @@ -7,6 +7,7 @@ from ilastik.workflows.pixelClassification import PixelClassificationWorkflow from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflow import numpy as np +import vigra from model_server.image import GenericImageDataAccessor, InMemoryDataAccessor from model_server.model import ImageToImageModel, ParameterExpectedError @@ -14,11 +15,6 @@ from model_server.model import ImageToImageModel, ParameterExpectedError class IlastikImageToImageModel(ImageToImageModel): - # workflows = { - # 'pixel_classification': PixelClassificationWorkflow, - # 'object_classification': ObjectClassificationWorkflow, - # } - def __init__(self, params, autoload=True): if 'project_file' not in params or not os.path.exists(params['project_file']): raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') @@ -56,20 +52,26 @@ class IlastikImageToImageModel(ImageToImageModel): return workflow.objectClassificationApplet.topLevelOperator - # def infer(self, img, channel=None) -> (np.ndarray, dict): - # assert self.operator.Classifier.ready() - class IlastikPixelClassifierModel(IlastikImageToImageModel): workflow = PixelClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict): + tagged_input_data = vigra.taggedView(input_img.data, 'xycz') dsi = [ { - 'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data[0:3]), + 'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), } ] - pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) - return InMemoryDataAccessor(data=pxmaps[0]) + pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [1 x w x h x n] + + assert(len(pxmaps) == 1, 'ilastik generated more than on pixel map') + + xycz = np.moveaxis( + pxmaps[0], + [2, 1, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=xycz) diff --git a/model_server/image.py b/model_server/image.py index 863c2ec65f9018757c8cb95170f900c7e1335298..ca015a9cb98e7cdd86c498ea027022269b1a8e76 100644 --- a/model_server/image.py +++ b/model_server/image.py @@ -95,7 +95,12 @@ class CziImageFileAccessor(GenericImageFileAccessor): def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor) -> bool: try: - tifffile.imwrite(fpath, accessor.data) + zcxy = np.moveaxis( + accessor.data, + [3, 2, 0, 1], + [0, 1, 2, 3] + ) + tifffile.imwrite(fpath, zcxy, imagej=True) except: raise FileWriteError(f'Unable to write data to file') return True diff --git a/tests/test_ilastik.py b/tests/test_ilastik.py index 94a7cbbb3da47de8643c06fad22752a83e45d867..ac387a41b600ff5287d284d329fd9752d430b845 100644 --- a/tests/test_ilastik.py +++ b/tests/test_ilastik.py @@ -1,6 +1,9 @@ import unittest + +import numpy as np + from conf.testing import czifile, ilastik, output_path -from model_server.image import CziImageFileAccessor, write_accessor_data_to_file +from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.ilastik import IlastikPixelClassifierModel class TestIlastikPixelClassification(unittest.TestCase): @@ -18,21 +21,37 @@ class TestIlastikPixelClassification(unittest.TestCase): with self.assertRaises(io.UnsupportedOperation): faulthandler.enable(file=sys.stdout) + def test_run_pixel_classifier_on_random_data(self): + model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) + w = 512 + h = 256 + + input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1)) + + pxmap = model.infer(input_img) + self.assertEqual(pxmap.shape, (w, h, 2, 1)) + def test_run_pixel_classifier(self): channel = 2 model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']}) cf = CziImageFileAccessor(czifile['path']) - pxmap = model.infer(cf.get_one_channel_data(channel)) + mono_image = cf.get_one_channel_data(channel) + + self.assertEqual(mono_image.shape_dict['X'], czifile['w']) + self.assertEqual(mono_image.shape_dict['Y'], czifile['h']) + self.assertEqual(mono_image.shape_dict['C'], 1) + self.assertEqual(mono_image.shape_dict['Z'], 1) + + pxmap = model.infer(mono_image) + self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) + self.assertEqual(pxmap.shape_dict['C'], 2) + self.assertEqual(pxmap.shape_dict['Z'], 1) print(pxmap.shape_dict) + self.assertTrue( + write_accessor_data_to_file( + output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', + pxmap + ) + ) - # self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) - # self.assertEqual(pxmap.shape_dict['C'], 2) - # self.assertEqual(pxmap.shape_dict['Z'], 1) - # - # self.assertTrue( - # write_accessor_data_to_file( - # output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', - # pxmap - # ) - # ) \ No newline at end of file