Skip to content
Snippets Groups Projects
Commit 5bbb91cb authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added superclass for instance segmentation models and modified ilastik object...

Added superclass for instance segmentation models and modified ilastik object classification to inherit from it
parent 5bd89178
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ import vigra
import extensions.ilastik.conf
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.models import Model, ParameterExpectedError, SemanticSegmentationModel
from model_server.models import Model, InstanceSegmentationModel, ParameterExpectedError, SemanticSegmentationModel
class IlastikModel(Model):
......@@ -114,7 +114,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel):
return InMemoryDataAccessor(data=yxcz), {'success': True}
class IlastikObjectClassifierFromSegmentationModel(IlastikModel):
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
......@@ -144,3 +144,8 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel):
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_classes(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
\ No newline at end of file
......@@ -52,6 +52,7 @@ class ImageToImageModel(Model):
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
pass
class SemanticSegmentationModel(ImageToImageModel):
"""
Model that exposes a method that returns a binary mask for a given input image and pixel class
......@@ -61,6 +62,18 @@ class SemanticSegmentationModel(ImageToImageModel):
def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict):
pass
class InstanceSegmentationModel(ImageToImageModel):
@abstractmethod
def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
if not mask.is_mask():
raise InvalidInputImageError('Expecting a binary mask')
if not img.shape == mask.shape:
raise InvalidInputImageError('Expect input image and mask to be the same shape')
class DummyImageToImageModel(ImageToImageModel):
model_id = 'dummy_make_white_square'
......@@ -84,4 +97,7 @@ class CouldNotLoadModelError(Error):
pass
class ParameterExpectedError(Error):
pass
class InvalidInputImageError(Error):
pass
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment