Skip to content
Snippets Groups Projects
Commit c3e3ff61 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covers instance segmentation base class

parent 16d96d55
No related branches found
No related tags found
No related merge requests found
...@@ -90,8 +90,7 @@ class InstanceSegmentationModel(ImageToImageModel): ...@@ -90,8 +90,7 @@ class InstanceSegmentationModel(ImageToImageModel):
raise InvalidInputImageError('Expect input image and mask to be the same shape') raise InvalidInputImageError('Expect input image and mask to be the same shape')
class DummySemanticSegmentationModel(SemanticSegmentationModel):
class DummySegmentationModel(SemanticSegmentationModel):
model_id = 'dummy_make_white_square' model_id = 'dummy_make_white_square'
...@@ -111,6 +110,26 @@ class DummySegmentationModel(SemanticSegmentationModel): ...@@ -111,6 +110,26 @@ class DummySegmentationModel(SemanticSegmentationModel):
mask, _ = self.infer(img) mask, _ = self.infer(img)
return mask return mask
class DummyInstanceSegmentationModel(InstanceSegmentationModel):
model_id = 'dummy_pass_input_mask'
def load(self):
return True
def infer(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor
) -> (GenericImageDataAccessor, dict):
return mask
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
"""
Returns a trivial segmentation, i.e. the input mask
"""
super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
class Error(Exception): class Error(Exception):
pass pass
......
import unittest import unittest
from conf.testing import czifile from conf.testing import czifile
from model_server.accessors import CziImageFileAccessor from model_server.accessors import CziImageFileAccessor
from model_server.models import DummySegmentationModel, CouldNotLoadModelError from model_server.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError
class TestCziImageFileAccess(unittest.TestCase): class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path']) self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self): def test_instantiate_model(self):
model = DummySegmentationModel(params=None) model = DummySemanticSegmentationModel(params=None)
self.assertTrue(model.loaded) self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self): def test_instantiate_model_with_nondefault_kwarg(self):
model = DummySegmentationModel(autoload=False) model = DummySemanticSegmentationModel(autoload=False)
self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
def test_raise_error_if_cannot_load_model(self): def test_raise_error_if_cannot_load_model(self):
class UnloadableDummyImageToImageModel(DummySegmentationModel): class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel):
def load(self): def load(self):
return False return False
with self.assertRaises(CouldNotLoadModelError): with self.assertRaises(CouldNotLoadModelError):
mi = UnloadableDummyImageToImageModel() mi = UnloadableDummyImageToImageModel()
def test_czifile_is_correct_shape(self): def test_dummy_pixel_segmentation(self):
model = DummySegmentationModel() model = DummySemanticSegmentationModel()
img, _ = model.infer(self.cf) img = self.cf.get_one_channel_data(0)
mask = model.label_pixel_class(img)
w = czifile['w'] w = czifile['w']
h = czifile['h'] h = czifile['h']
self.assertEqual( self.assertEqual(
img.shape, mask.shape,
(h, w, 1, 1), (h, w, 1, 1),
'Inferred image is not the expected shape' 'Inferred image is not the expected shape'
) )
self.assertEqual( self.assertEqual(
img.data[int(w/2), int(h/2)], mask.data[int(w/2), int(h/2)],
255, 255,
'Middle pixel is not white as expected' 'Middle pixel is not white as expected'
) )
self.assertEqual( self.assertEqual(
img.data[0, 0], mask.data[0, 0],
0, 0,
'First pixel is not black as expected' 'First pixel is not black as expected'
) )
\ No newline at end of file return img, mask
def test_dummy_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = DummyInstanceSegmentationModel()
obmap = model.label_instance_class(img, mask)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment