diff --git a/model_server/base/models.py b/model_server/base/models.py index 26e8af624eb1a98ae2067785b06a55b40dc45621..9aadc1c9a954eaf766c3bdfe6f9457db92c000ad 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -35,7 +35,7 @@ class Model(ABC): pass @abstractmethod - def infer(self, *args) -> (object, dict): + def infer(self, *args) -> object: """ Abstract method that carries out the computationally intensive step of running data through a model :param args: @@ -58,7 +58,7 @@ class ImageToImageModel(Model): """ @abstractmethod - def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): + def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor: pass @@ -86,6 +86,25 @@ class SemanticSegmentationModel(ImageToImageModel): return PatchStack(data) +class BinaryThresholdSegmentationModel(SemanticSegmentationModel): + """ + Trivial but functional model that labels all pixels above an intensity threshold as class 1 + """ + + def __init__(self, params=None): + self.tr = params['tr'] + self.loaded = self.load() + + def infer(self, acc: GenericImageDataAccessor) -> GenericImageDataAccessor: + return acc.apply(lambda x: x > self.tr) + + def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + return self.infer(acc, **kwargs) + + def load(self): + return True + + class InstanceSegmentationModel(ImageToImageModel): """ Base model that exposes a method that returns an instance classification map for a given input image and mask @@ -127,24 +146,25 @@ class InstanceSegmentationModel(ImageToImageModel): return PatchStack(data) -class BinaryThresholdSegmentationModel(SemanticSegmentationModel): +class PermissiveInstanceSegmentationModel(InstanceSegmentationModel): """ - Trivial but functional model that labels all pixels above an intensity threshold as class 1 + Trivial but functional model that labels all objects as class 1 """ def __init__(self, params=None): - self.tr = params['tr'] - self.loaded = True - - def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict): - return acc.apply(lambda x: x > self.tr) - - def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: - return self.infer(acc, **kwargs) + self.loaded = self.load() def load(self): - pass + return True + + def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor: + return mask.apply(lambda x: (1 * (x > 0)).astype(acc.dtype)) + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + super().label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) class Error(Exception): pass diff --git a/tests/base/test_model.py b/tests/base/test_model.py index 8340111c5b45e764b76927c1534a64486d9cc2b6..62d516bbb7b239df5a61f4af2db65f7dbe779c34 100644 --- a/tests/base/test_model.py +++ b/tests/base/test_model.py @@ -1,9 +1,11 @@ import unittest +import numpy as np + import model_server.conf.testing as conf from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel from model_server.base.accessors import CziImageFileAccessor -from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel +from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel czifile = conf.meta['image_files']['czifile'] @@ -64,3 +66,11 @@ class TestCziImageFileAccess(unittest.TestCase): img, mask = self.test_dummy_pixel_segmentation() model = DummyInstanceSegmentationModel() obmap = model.label_instance_class(img, mask) + self.assertTrue(all(obmap.unique()[0] == [0, 1])) + self.assertTrue(all(obmap.unique()[1] > 0)) + + def test_permissive_instance_segmentation(self): + img, mask = self.test_dummy_pixel_segmentation() + model = PermissiveInstanceSegmentationModel() + obmap = model.label_instance_class(img, mask) + self.assertTrue(np.all(mask.data == 255 * obmap.data))