diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 07eba484b24c7b64b729ee1bb71de708ecf1f934..8cba54bcb199ece41610aa5dd138df5eda93ec27 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -162,6 +162,14 @@ class CziImageFileAccessor(GenericImageFileAccessor): except Exception: raise FileAccessorError(f'Unable to access CZI data in {fpath}') + try: + md = cf.metadata(raw=False) + compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod'] + except KeyError: + raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod') + if compmet.upper() != 'UNCOMPRESSED': + raise InvalidCziCompression(f'Unsupported compression method {compmet}') + sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes} if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)): raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}') @@ -284,6 +292,22 @@ class PatchStack(InMemoryDataAccessor): def count(self): return self.shape_dict['P'] + def export_pyxcz(self, fpath: Path): + tzcyx = np.moveaxis( + self.pyxcz, # yxcz + [0, 4, 3, 1, 2], + [0, 1, 2, 3, 4] + ) + + if self.is_mask(): + if self.dtype == 'bool': + data = (tzcyx * 255).astype('uint8') + else: + data = tzcyx.astype('uint8') + tifffile.imwrite(fpath, data, imagej=True) + else: + tifffile.imwrite(fpath, tzcyx, imagej=True) + @property def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) @@ -313,16 +337,32 @@ class PatchStack(InMemoryDataAccessor): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) -def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position +def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position if not Path(fpath).exists(): raise FileNotFoundError(f'Could not find {fpath}') - pyxc = np.moveaxis( - generate_file_accessor(fpath).data, # yxcz - [0, 1, 2, 3], - [1, 2, 3, 0] + try: + tf = tifffile.TiffFile(fpath) + except Exception: + raise FileAccessorError(f'Unable to access data in {fpath}') + + if len(tf.series) != 1: + raise DataShapeError(f'Expect only one series in {fpath}') + + se = tf.series[0] + + axs = [a for a in se.axes if a in [*'TZCYX']] + sd = dict(zip(axs, se.shape)) + for a in [*'TZC']: + if a not in axs: + sd[a] = 1 + tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']]) + + pyxcz = np.moveaxis( + tzcyx, + [0, 3, 4, 2, 1], + [0, 1, 2, 3, 4], ) - pyxcz = np.expand_dims(pyxc, axis=3) return PatchStack(pyxcz) @@ -345,6 +385,9 @@ class FileWriteError(Error): class InvalidAxisKey(Error): pass +class InvalidCziCompression(Error): + pass + class InvalidDataShape(Error): pass diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index de25566a6571b6a07f8f6e1d453e49a1b13f9c72..d17e69545c0aa746c3ebd17ddc89580ace56eadd 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path @@ -12,8 +13,17 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat class IlastikModel(Model): - def __init__(self, params, autoload=True): + def __init__(self, params, autoload=True, enforce_embedded=True): + """ + Base class for models that run via ilastik shell API + :param params: + project_file: path to ilastik project file + :param autoload: automatically load model into memory if true + :param enforce_embedded: + raise an error if all input data are not embedded in the project file, i.e. on the filesystem + """ self.project_file = Path(params['project_file']) + self.enforce_embedded = enforce_embedded params['project_file'] = self.project_file.__str__() if self.project_file.is_absolute(): pap = self.project_file @@ -42,6 +52,15 @@ class IlastikModel(Model): args.project = self.project_file_abspath.__str__() shell = app.main(args, init_logging=False) + # validate if inputs are embedded in project file + input_groups = shell.projectManager.currentProjectFile['Input Data']['infos'] + lanes = input_groups.keys() + for ll in lanes: + input_types = input_groups[ll] + for tt in input_types: + ds_loc = input_groups[ll][tt].get('location', False) + if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem': + raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem') if not isinstance(shell.workflow, self.get_workflow()): raise ParameterExpectedError( f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' @@ -55,12 +74,31 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' operations = ['segment', ] + @property + def model_shape_dict(self): + raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] + ax = raw_info['axistags'][()] + ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']] + shape = raw_info['shape'][()] + return dict(zip(ax_keys, shape)) + + @property + def model_chroma(self): + return self.model_shape_dict['C'] + + @property + def model_3d(self): + return self.model_shape_dict['Z'] > 1 + @staticmethod def get_workflow(): from ilastik.workflows import PixelClassificationWorkflow return PixelClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ { @@ -87,22 +125,33 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): model_id = 'ilastik_object_classification_from_segmentation' + @staticmethod + def _make_8bit_mask(nda): + if nda.dtype == 'bool': + return 255 * nda.astype('uint8') + else: + return nda + @staticmethod def get_workflow(): from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') assert segmentation_img.is_mask() - if segmentation_img.dtype == 'bool': - seg = 255 * segmentation_img.data.astype('uint8') + if isinstance(input_img, PatchStack): + assert isinstance(segmentation_img, PatchStack) + tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView( - 255 * segmentation_img.data.astype('uint8'), - 'yxcz' + self._make_8bit_mask(segmentation_img.pczyx), + 'tczyx' ) else: - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_seg_data = vigra.taggedView( + self._make_8bit_mask(segmentation_img.data), + 'yxcz' + ) dsi = [ { @@ -115,12 +164,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment assert len(obmaps) == 1, 'ilastik generated more than one object map' - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + + if isinstance(input_img, PatchStack): + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + return PatchStack(data=pyxcz), {'success': True} + else: + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) @@ -172,48 +230,19 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag """ if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') - # TODO: check that pxmap is in-range pxch = kwargs.get('pixel_classification_channel', 0) pxtr = kwargs('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) - # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) obmap, _ = self.infer(img, mask) return obmap -class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): - """ - Wrap ilastik object classification for inputs comprising single-object series of raw images and binary - segmentation masks. - """ +class Error(Exception): + pass - def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): - assert segmentation_acc.is_mask() - if not input_acc.chroma == 1: - raise InvalidInputImageError('Object classifier expects only monochrome patches') - if not input_acc.nz == 1: - raise InvalidInputImageError('Object classifier expects only 2d patches') - - tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') - tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') - - dsi = [ - { - 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), - } - ] - - obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - - assert len(obmaps) == 1, 'ilastik generated more than one object map' - - # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C - assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) - pyxcz = np.moveaxis( - obmaps[0], - [0, 1, 2, 3, 4], - [0, 4, 1, 2, 3] - ) +class IlastikInputEmbedding(Error): + pass - return PatchStack(data=pyxcz), {'success': True} \ No newline at end of file +class IlastikInputShapeError(Error): + """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape""" + pass \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 32dd137683a1a65a0fcbfc48535dd8a88a221213..7adde1a4cbfcfaa953ff927ca9c0d7f1b4221a9f 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) @@ -83,6 +86,40 @@ class TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask + def test_pixel_classifier_enforces_input_shape(self): + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} + ) + self.assertEqual(model.model_chroma, 1) + self.assertEqual(model.model_3d, False) + + # correct data + self.assertIsInstance( + model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 1) + ) + ), + InMemoryDataAccessor + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 3, 1) + ) + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 15) + ) + ) + + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] @@ -97,7 +134,7 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 3) + self.assertEqual(objmap.data.max(), 2) def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() @@ -113,7 +150,7 @@ class TestIlastikPixelClassification(unittest.TestCase): objmap, ) ) - self.assertEqual(objmap.data.max(), 3) + self.assertEqual(objmap.data.max(), 2) def test_ilastik_pixel_classification_as_workflow(self): result = classify_pixels( @@ -269,16 +306,18 @@ class TestIlastikObjectClassification(unittest.TestCase): ) ) - self.object_classifier = ilm.PatchStackObjectClassifier( + self.object_classifier = ilm.IlastikObjectClassifierFromSegmentationModel( params={'project_file': ilastik_classifiers['seg_to_obj']} ) + def test_classify_patches(self): raw_patches = self.roiset.get_raw_patches() patch_masks = self.roiset.get_patch_masks() - res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks) + res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks) self.assertEqual(res_patches.count, self.roiset.count) + res_patches.export_pyxcz(output_path / 'res_patches.tif') for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch - unique = np.unique(res_patches.iat(pi).data) - self.assertEqual(len(unique), 2) - self.assertEqual(unique[0], 0) + la, ct = np.unique(res_patches.iat(pi).data, return_counts=True) + self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly + self.assertEqual(la[0], 0) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index bc5b4065f314eadd66692a806cb2356eccdc28e2..d2ca777c988c0101d8f7f4efb09611f4fe7a71f7 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -7,6 +7,9 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file, from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: @@ -40,7 +43,7 @@ class TestCziImageFileAccess(unittest.TestCase): nc = 4 nz = 11 c = 3 - cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) sc = cf.get_one_channel_data(c) self.assertEqual(sc.shape, (h, w, 1, nz)) @@ -70,7 +73,7 @@ class TestCziImageFileAccess(unittest.TestCase): def test_conform_data_shorter_than_xycz(self): h = 256 w = 512 - data = np.random.rand(h, w, 1) + data = _random_int(h, w, 1) acc = InMemoryDataAccessor(data) self.assertEqual( InMemoryDataAccessor.conform_data(data).shape, @@ -82,7 +85,7 @@ class TestCziImageFileAccess(unittest.TestCase): ) def test_conform_data_longer_than_xycz(self): - data = np.random.rand(256, 512, 12, 8, 3) + data = _random_int(256, 512, 12, 8, 3) with self.assertRaises(DataShapeError): acc = InMemoryDataAccessor(data) @@ -93,7 +96,7 @@ class TestCziImageFileAccess(unittest.TestCase): c = 3 nz = 10 - yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8') + yxcz = _random_int(h, w, c, nz) acc = InMemoryDataAccessor(yxcz) fp = output_path / f'rand3d.tif' self.assertTrue( @@ -138,7 +141,7 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack(np.random.rand(n, h, w, 1, 1)) + acc = PatchStack(_random_int(n, h, w, 1, 1)) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -147,7 +150,7 @@ class TestPatchStackAccessor(unittest.TestCase): w = 256 h = 512 n = 4 - acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)]) + acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)]) self.assertEqual(acc.count, n) self.assertEqual(acc.hw, (h, w)) self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1)) @@ -176,8 +179,8 @@ class TestPatchStackAccessor(unittest.TestCase): nz = 5 n = 4 - patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)] - patches.append(np.random.rand(h, 2 * w, c, nz)) + patches = [_random_int(h, w, c, nz) for _ in range(0, n)] + patches.append(_random_int(h, 2 * w, c, nz)) acc = PatchStack(patches) self.assertEqual(acc.count, n + 1) self.assertEqual(acc.hw, (h, 2 * w)) @@ -191,7 +194,15 @@ class TestPatchStackAccessor(unittest.TestCase): n = 4 nz = 15 nc = 2 - acc = PatchStack(np.random.rand(n, h, w, nc, nz)) + acc = PatchStack(_random_int(n, h, w, nc, nz)) self.assertEqual(acc.count, n) self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w)) self.assertEqual(acc.hw, (h, w)) + return acc + + def test_export_pczyx_patch_hyperstack(self): + acc = self.test_pczyx() + fp = output_path / 'patch_hyperstack.tif' + acc.export_pyxcz(fp) + acc2 = make_patch_stack_from_file(fp) + self.assertEqual(acc.shape, acc2.shape) \ No newline at end of file