diff --git a/model_server/models.py b/model_server/models.py index f0864147279dbef501138307ae885a742d390757..06bef01c9831cb495b32662c4bbdb8970589ada7 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -90,8 +90,7 @@ class InstanceSegmentationModel(ImageToImageModel): raise InvalidInputImageError('Expect input image and mask to be the same shape') - -class DummySegmentationModel(SemanticSegmentationModel): +class DummySemanticSegmentationModel(SemanticSegmentationModel): model_id = 'dummy_make_white_square' @@ -111,6 +110,26 @@ class DummySegmentationModel(SemanticSegmentationModel): mask, _ = self.infer(img) return mask +class DummyInstanceSegmentationModel(InstanceSegmentationModel): + + model_id = 'dummy_pass_input_mask' + + def load(self): + return True + + def infer( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor + ) -> (GenericImageDataAccessor, dict): + return mask + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Returns a trivial segmentation, i.e. the input mask + """ + super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) class Error(Exception): pass diff --git a/tests/test_model.py b/tests/test_model.py index 0d5f98aea4aee70da25ef3edf71b36eebac01a3e..8730c666047ad49660d06f445f8e8ea944cdd007 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,49 +1,56 @@ import unittest from conf.testing import czifile from model_server.accessors import CziImageFileAccessor -from model_server.models import DummySegmentationModel, CouldNotLoadModelError +from model_server.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummySegmentationModel(params=None) + model = DummySemanticSegmentationModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): - model = DummySegmentationModel(autoload=False) + model = DummySemanticSegmentationModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') def test_raise_error_if_cannot_load_model(self): - class UnloadableDummyImageToImageModel(DummySegmentationModel): + class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel): def load(self): return False with self.assertRaises(CouldNotLoadModelError): mi = UnloadableDummyImageToImageModel() - def test_czifile_is_correct_shape(self): - model = DummySegmentationModel() - img, _ = model.infer(self.cf) + def test_dummy_pixel_segmentation(self): + model = DummySemanticSegmentationModel() + img = self.cf.get_one_channel_data(0) + mask = model.label_pixel_class(img) w = czifile['w'] h = czifile['h'] self.assertEqual( - img.shape, + mask.shape, (h, w, 1, 1), 'Inferred image is not the expected shape' ) self.assertEqual( - img.data[int(w/2), int(h/2)], + mask.data[int(w/2), int(h/2)], 255, 'Middle pixel is not white as expected' ) self.assertEqual( - img.data[0, 0], + mask.data[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + return img, mask + + def test_dummy_instance_segmentation(self): + img, mask = self.test_dummy_pixel_segmentation() + model = DummyInstanceSegmentationModel() + obmap = model.label_instance_class(img, mask)