Skip to content
Snippets Groups Projects
Commit fc16982f authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Added trivial object classifier, too

parent efce29dc
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ class Model(ABC):
pass
@abstractmethod
def infer(self, *args) -> (object, dict):
def infer(self, *args) -> object:
"""
Abstract method that carries out the computationally intensive step of running data through a model
:param args:
......@@ -58,7 +58,7 @@ class ImageToImageModel(Model):
"""
@abstractmethod
def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
def infer(self, img: GenericImageDataAccessor) -> GenericImageDataAccessor:
pass
......@@ -86,6 +86,25 @@ class SemanticSegmentationModel(ImageToImageModel):
return PatchStack(data)
class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
"""
Trivial but functional model that labels all pixels above an intensity threshold as class 1
"""
def __init__(self, params=None):
self.tr = params['tr']
self.loaded = self.load()
def infer(self, acc: GenericImageDataAccessor) -> GenericImageDataAccessor:
return acc.apply(lambda x: x > self.tr)
def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(acc, **kwargs)
def load(self):
return True
class InstanceSegmentationModel(ImageToImageModel):
"""
Base model that exposes a method that returns an instance classification map for a given input image and mask
......@@ -127,24 +146,25 @@ class InstanceSegmentationModel(ImageToImageModel):
return PatchStack(data)
class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
class PermissiveInstanceSegmentationModel(InstanceSegmentationModel):
"""
Trivial but functional model that labels all pixels above an intensity threshold as class 1
Trivial but functional model that labels all objects as class 1
"""
def __init__(self, params=None):
self.tr = params['tr']
self.loaded = True
def infer(self, acc: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
return acc.apply(lambda x: x > self.tr)
def label_pixel_class(self, acc: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
return self.infer(acc, **kwargs)
self.loaded = self.load()
def load(self):
pass
return True
def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
return mask.apply(lambda x: (1 * (x > 0)).astype(acc.dtype))
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
super().label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
class Error(Exception):
pass
......
import unittest
import numpy as np
import model_server.conf.testing as conf
from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel
from model_server.base.accessors import CziImageFileAccessor
from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel
from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel
czifile = conf.meta['image_files']['czifile']
......@@ -64,3 +66,11 @@ class TestCziImageFileAccess(unittest.TestCase):
img, mask = self.test_dummy_pixel_segmentation()
model = DummyInstanceSegmentationModel()
obmap = model.label_instance_class(img, mask)
self.assertTrue(all(obmap.unique()[0] == [0, 1]))
self.assertTrue(all(obmap.unique()[1] > 0))
def test_permissive_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = PermissiveInstanceSegmentationModel()
obmap = model.label_instance_class(img, mask)
self.assertTrue(np.all(mask.data == 255 * obmap.data))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment