Skip to content
Snippets Groups Projects

Resolve "ilastik models do not validate dimensionality of input data"

1 file
+ 32
51
Compare changes
  • Side-by-side
  • Inline
@@ -105,22 +105,33 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
def _make_8bit_mask(nda):
if nda.dtype == 'bool':
return 255 * nda.astype('uint8')
else:
return nda
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
assert segmentation_img.is_mask()
if segmentation_img.dtype == 'bool':
seg = 255 * segmentation_img.data.astype('uint8')
if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack)
tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(
255 * segmentation_img.data.astype('uint8'),
'yxcz'
self._make_8bit_mask(segmentation_img.pczyx),
'tczyx'
)
else:
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(
self._make_8bit_mask(segmentation_img.data),
'yxcz'
)
dsi = [
{
@@ -133,12 +144,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
assert len(obmaps) == 1, 'ilastik generated more than one object map'
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
if isinstance(input_img, PatchStack):
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
return PatchStack(data=pyxcz), {'success': True}
else:
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
@@ -190,52 +210,13 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
"""
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
# TODO: check that pxmap is in-range
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
# super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
"""
Wrap ilastik object classification for inputs comprising single-object series of raw images and binary
segmentation masks.
"""
def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask()
if not input_acc.chroma == 1:
raise InvalidInputImageError('Object classifier expects only monochrome patches')
if not input_acc.nz == 1:
raise InvalidInputImageError('Object classifier expects only 2d patches')
tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
}
]
obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert len(obmaps) == 1, 'ilastik generated more than one object map'
# for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
return PatchStack(data=pyxcz), {'success': True}
class Error(Exception):
pass
Loading