Skip to content
Snippets Groups Projects
Commit b17d32c6 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Retained object classification from pixel probabilitie

parent bbc607cc
No related branches found
No related tags found
No related merge requests found
...@@ -6,7 +6,7 @@ import vigra ...@@ -6,7 +6,7 @@ import vigra
import extensions.ilastik.conf import extensions.ilastik.conf
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.models import Model, InstanceSegmentationModel, ParameterExpectedError, SemanticSegmentationModel from model_server.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikModel(Model): class IlastikModel(Model):
...@@ -82,23 +82,24 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ...@@ -82,23 +82,24 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
return InMemoryDataAccessor(mask) return InMemoryDataAccessor(mask)
# TODO: deprecate
class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel): class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_pixel_predictions' model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod @staticmethod
def get_workflow(): def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowPrediction return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
assert segmentation_img.is_mask()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
dsi = [ dsi = [
{ {
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data), 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
} }
] ]
...@@ -113,24 +114,28 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel): ...@@ -113,24 +114,28 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
) )
return InMemoryDataAccessor(data=yxcz), {'success': True} return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation' class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
model_id = 'ilastik_object_classification_from_pixel_predictions'
@staticmethod @staticmethod
def get_workflow(): def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
return ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
assert segmentation_img.is_mask()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
dsi = [ dsi = [
{ {
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), 'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
} }
] ]
...@@ -145,7 +150,26 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment ...@@ -145,7 +150,26 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
) )
return InMemoryDataAccessor(data=yxcz), {'success': True} return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
"""
Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
assigned a class.
:param img: input image
:param pxmap: map of pixel probabilities
:param kwargs:
pixel_classification_channel: channel of pxmap used to segment objects
pixel_classification_thresold: threshold of pxmap used to segment objects
:return:
"""
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
# TODO: check that pxmap is in-range
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
# super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask) obmap, _ = self.infer(img, mask)
return obmap return obmap
\ No newline at end of file
...@@ -34,6 +34,8 @@ class GenericImageDataAccessor(ABC): ...@@ -34,6 +34,8 @@ class GenericImageDataAccessor(ABC):
def is_3d(self): def is_3d(self):
return True if self.shape_dict['Z'] > 1 else False return True if self.shape_dict['Z'] > 1 else False
# TODO: implement is_probability
def is_mask(self): def is_mask(self):
return is_mask(self._data) return is_mask(self._data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment