Skip to content
Snippets Groups Projects
Commit c3e3ff61 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covers instance segmentation base class

parent 16d96d55
No related branches found
No related tags found
No related merge requests found
......@@ -90,8 +90,7 @@ class InstanceSegmentationModel(ImageToImageModel):
raise InvalidInputImageError('Expect input image and mask to be the same shape')
class DummySegmentationModel(SemanticSegmentationModel):
class DummySemanticSegmentationModel(SemanticSegmentationModel):
model_id = 'dummy_make_white_square'
......@@ -111,6 +110,26 @@ class DummySegmentationModel(SemanticSegmentationModel):
mask, _ = self.infer(img)
return mask
class DummyInstanceSegmentationModel(InstanceSegmentationModel):
model_id = 'dummy_pass_input_mask'
def load(self):
return True
def infer(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor
) -> (GenericImageDataAccessor, dict):
return mask
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
"""
Returns a trivial segmentation, i.e. the input mask
"""
super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs)
return self.infer(img, mask)
class Error(Exception):
pass
......
import unittest
from conf.testing import czifile
from model_server.accessors import CziImageFileAccessor
from model_server.models import DummySegmentationModel, CouldNotLoadModelError
from model_server.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model(self):
model = DummySegmentationModel(params=None)
model = DummySemanticSegmentationModel(params=None)
self.assertTrue(model.loaded)
def test_instantiate_model_with_nondefault_kwarg(self):
model = DummySegmentationModel(autoload=False)
model = DummySemanticSegmentationModel(autoload=False)
self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
def test_raise_error_if_cannot_load_model(self):
class UnloadableDummyImageToImageModel(DummySegmentationModel):
class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel):
def load(self):
return False
with self.assertRaises(CouldNotLoadModelError):
mi = UnloadableDummyImageToImageModel()
def test_czifile_is_correct_shape(self):
model = DummySegmentationModel()
img, _ = model.infer(self.cf)
def test_dummy_pixel_segmentation(self):
model = DummySemanticSegmentationModel()
img = self.cf.get_one_channel_data(0)
mask = model.label_pixel_class(img)
w = czifile['w']
h = czifile['h']
self.assertEqual(
img.shape,
mask.shape,
(h, w, 1, 1),
'Inferred image is not the expected shape'
)
self.assertEqual(
img.data[int(w/2), int(h/2)],
mask.data[int(w/2), int(h/2)],
255,
'Middle pixel is not white as expected'
)
self.assertEqual(
img.data[0, 0],
mask.data[0, 0],
0,
'First pixel is not black as expected'
)
\ No newline at end of file
)
return img, mask
def test_dummy_instance_segmentation(self):
img, mask = self.test_dummy_pixel_segmentation()
model = DummyInstanceSegmentationModel()
obmap = model.label_instance_class(img, mask)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment