diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 9b55d6d0eeec83f1bab962f18d81fd275eee5b0a..864413d814bb6bf8073287ac8be5cb93325729e7 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -6,7 +6,7 @@ import vigra import extensions.ilastik.conf from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from model_server.models import Model, InstanceSegmentationModel, ParameterExpectedError, SemanticSegmentationModel +from model_server.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel class IlastikModel(Model): @@ -82,23 +82,24 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold return InMemoryDataAccessor(mask) -# TODO: deprecate -class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel): - model_id = 'ilastik_object_classification_from_pixel_predictions' + +class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): + model_id = 'ilastik_object_classification_from_segmentation' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction - return ObjectClassificationWorkflowPrediction + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary + return ObjectClassificationWorkflowBinary - def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): + def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): + assert segmentation_img.is_mask() tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') dsi = [ { 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), } ] @@ -113,24 +114,28 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} + def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + obmap, _ = self.infer(img, mask) + return obmap -class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): - model_id = 'ilastik_object_classification_from_segmentation' + +class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel): + model_id = 'ilastik_object_classification_from_pixel_predictions' @staticmethod def get_workflow(): - from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary - return ObjectClassificationWorkflowBinary + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction + return ObjectClassificationWorkflowPrediction - def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - assert segmentation_img.is_mask() + def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') dsi = [ { 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), - 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), } ] @@ -145,7 +150,26 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment ) return InMemoryDataAccessor(data=yxcz), {'success': True} - def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): - super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) + + def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs): + """ + Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is + assigned a class. + :param img: input image + :param pxmap: map of pixel probabilities + :param kwargs: + pixel_classification_channel: channel of pxmap used to segment objects + pixel_classification_thresold: threshold of pxmap used to segment objects + :return: + """ + if not img.shape == pxmap.shape: + raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') + # TODO: check that pxmap is in-range + pxch = kwargs.get('pixel_classification_channel', 0) + pxtr = kwargs('pixel_classification_threshold', 0.5) + mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) + # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) obmap, _ = self.infer(img, mask) - return obmap \ No newline at end of file + return obmap + + diff --git a/model_server/accessors.py b/model_server/accessors.py index 4f122fb13d8022255f5866daef4646dd3940a00e..04cb35646e0a0cde44dd7d8bc028fd80b7a5ad6f 100644 --- a/model_server/accessors.py +++ b/model_server/accessors.py @@ -34,6 +34,8 @@ class GenericImageDataAccessor(ABC): def is_3d(self): return True if self.shape_dict['Z'] > 1 else False + # TODO: implement is_probability + def is_mask(self): return is_mask(self._data)