diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 8f22c3ad24310af9fb33e17b938a22705a3dbe4e..f3350ab7aca311479c1dba818a82318893ac661c 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -6,7 +6,7 @@ import vigra import extensions.ilastik.conf from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor -from model_server.models import Model, ParameterExpectedError, SemanticSegmentationModel +from model_server.models import Model, InstanceSegmentationModel, ParameterExpectedError, SemanticSegmentationModel class IlastikModel(Model): @@ -114,7 +114,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel): return InMemoryDataAccessor(data=yxcz), {'success': True} -class IlastikObjectClassifierFromSegmentationModel(IlastikModel): +class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel): model_id = 'ilastik_object_classification_from_segmentation' @staticmethod @@ -144,3 +144,8 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel): [0, 1, 2, 3] ) return InMemoryDataAccessor(data=yxcz), {'success': True} + + def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_classes(img, mask, **kwargs) + obmap, _ = self.infer(img, mask) + return obmap \ No newline at end of file diff --git a/model_server/models.py b/model_server/models.py index 07fc31dfdf2ea0b85ad5e63fa0daeb8bd15242fd..087a9975fb908498896c6a17d36ecd9af137ba74 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -52,6 +52,7 @@ class ImageToImageModel(Model): def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): pass + class SemanticSegmentationModel(ImageToImageModel): """ Model that exposes a method that returns a binary mask for a given input image and pixel class @@ -61,6 +62,18 @@ class SemanticSegmentationModel(ImageToImageModel): def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict): pass + +class InstanceSegmentationModel(ImageToImageModel): + + @abstractmethod + def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + if not mask.is_mask(): + raise InvalidInputImageError('Expecting a binary mask') + if not img.shape == mask.shape: + raise InvalidInputImageError('Expect input image and mask to be the same shape') + + + class DummyImageToImageModel(ImageToImageModel): model_id = 'dummy_make_white_square' @@ -84,4 +97,7 @@ class CouldNotLoadModelError(Error): pass class ParameterExpectedError(Error): + pass + +class InvalidInputImageError(Error): pass \ No newline at end of file