Skip to content
Snippets Groups Projects
Commit b17d32c6 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Retained object classification from pixel probabilitie

parent bbc607cc
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import vigra
import extensions.ilastik.conf
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.models import Model, InstanceSegmentationModel, ParameterExpectedError, SemanticSegmentationModel
from model_server.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikModel(Model):
......@@ -82,23 +82,24 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold
return InMemoryDataAccessor(mask)
# TODO: deprecate
class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
model_id = 'ilastik_object_classification_from_pixel_predictions'
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
return ObjectClassificationWorkflowPrediction
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
assert segmentation_img.is_mask()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
}
]
......@@ -113,24 +114,28 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
model_id = 'ilastik_object_classification_from_pixel_predictions'
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowBinary
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
assert segmentation_img.is_mask()
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
}
]
......@@ -145,7 +150,26 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
"""
Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
assigned a class.
:param img: input image
:param pxmap: map of pixel probabilities
:param kwargs:
pixel_classification_channel: channel of pxmap used to segment objects
pixel_classification_thresold: threshold of pxmap used to segment objects
:return:
"""
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
# TODO: check that pxmap is in-range
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
# super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
\ No newline at end of file
return obmap
......@@ -34,6 +34,8 @@ class GenericImageDataAccessor(ABC):
def is_3d(self):
return True if self.shape_dict['Z'] > 1 else False
# TODO: implement is_probability
def is_mask(self):
return is_mask(self._data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment