diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 00224805f7c7a7106125be53ccf36bafad5ec579..48a18758ea486523cb6c38c31a7e7de695d10800 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -105,22 +105,33 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): model_id = 'ilastik_object_classification_from_segmentation' + @staticmethod + def _make_8bit_mask(nda): + if nda.dtype == 'bool': + return 255 * nda.astype('uint8') + else: + return nda + @staticmethod def get_workflow(): from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') assert segmentation_img.is_mask() - if segmentation_img.dtype == 'bool': - seg = 255 * segmentation_img.data.astype('uint8') + if isinstance(input_img, PatchStack): + assert isinstance(segmentation_img, PatchStack) + tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx') tagged_seg_data = vigra.taggedView( - 255 * segmentation_img.data.astype('uint8'), - 'yxcz' + self._make_8bit_mask(segmentation_img.pczyx), + 'tczyx' ) else: - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_seg_data = vigra.taggedView( + self._make_8bit_mask(segmentation_img.data), + 'yxcz' + ) dsi = [ { @@ -133,12 +144,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment assert len(obmaps) == 1, 'ilastik generated more than one object map' - yxcz = np.moveaxis( - obmaps[0], - [1, 2, 3, 0], - [0, 1, 2, 3] - ) - return InMemoryDataAccessor(data=yxcz), {'success': True} + + if isinstance(input_img, PatchStack): + pyxcz = np.moveaxis( + obmaps[0], + [0, 1, 2, 3, 4], + [0, 4, 1, 2, 3] + ) + return PatchStack(data=pyxcz), {'success': True} + else: + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) @@ -190,52 +210,13 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag """ if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') - # TODO: check that pxmap is in-range pxch = kwargs.get('pixel_classification_channel', 0) pxtr = kwargs('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) - # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) obmap, _ = self.infer(img, mask) return obmap -class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel): - """ - Wrap ilastik object classification for inputs comprising single-object series of raw images and binary - segmentation masks. - """ - - def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict): - assert segmentation_acc.is_mask() - if not input_acc.chroma == 1: - raise InvalidInputImageError('Object classifier expects only monochrome patches') - if not input_acc.nz == 1: - raise InvalidInputImageError('Object classifier expects only 2d patches') - - tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx') - tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx') - - dsi = [ - { - 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), - } - ] - - obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - - assert len(obmaps) == 1, 'ilastik generated more than one object map' - - # for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C - assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1) - pyxcz = np.moveaxis( - obmaps[0], - [0, 1, 2, 3, 4], - [0, 4, 1, 2, 3] - ) - - return PatchStack(data=pyxcz), {'success': True} - class Error(Exception): pass