From c3e3ff61c697b32900577420a7008111d3148eb0 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 15:30:53 +0100 Subject: [PATCH] Test covers instance segmentation base class --- model_server/models.py | 23 +++++++++++++++++++++-- tests/test_model.py | 29 ++++++++++++++++++----------- 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/model_server/models.py b/model_server/models.py index f0864147..06bef01c 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -90,8 +90,7 @@ class InstanceSegmentationModel(ImageToImageModel): raise InvalidInputImageError('Expect input image and mask to be the same shape') - -class DummySegmentationModel(SemanticSegmentationModel): +class DummySemanticSegmentationModel(SemanticSegmentationModel): model_id = 'dummy_make_white_square' @@ -111,6 +110,26 @@ class DummySegmentationModel(SemanticSegmentationModel): mask, _ = self.infer(img) return mask +class DummyInstanceSegmentationModel(InstanceSegmentationModel): + + model_id = 'dummy_pass_input_mask' + + def load(self): + return True + + def infer( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor + ) -> (GenericImageDataAccessor, dict): + return mask + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Returns a trivial segmentation, i.e. the input mask + """ + super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs) + return self.infer(img, mask) class Error(Exception): pass diff --git a/tests/test_model.py b/tests/test_model.py index 0d5f98ae..8730c666 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,49 +1,56 @@ import unittest from conf.testing import czifile from model_server.accessors import CziImageFileAccessor -from model_server.models import DummySegmentationModel, CouldNotLoadModelError +from model_server.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) def test_instantiate_model(self): - model = DummySegmentationModel(params=None) + model = DummySemanticSegmentationModel(params=None) self.assertTrue(model.loaded) def test_instantiate_model_with_nondefault_kwarg(self): - model = DummySegmentationModel(autoload=False) + model = DummySemanticSegmentationModel(autoload=False) self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') def test_raise_error_if_cannot_load_model(self): - class UnloadableDummyImageToImageModel(DummySegmentationModel): + class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel): def load(self): return False with self.assertRaises(CouldNotLoadModelError): mi = UnloadableDummyImageToImageModel() - def test_czifile_is_correct_shape(self): - model = DummySegmentationModel() - img, _ = model.infer(self.cf) + def test_dummy_pixel_segmentation(self): + model = DummySemanticSegmentationModel() + img = self.cf.get_one_channel_data(0) + mask = model.label_pixel_class(img) w = czifile['w'] h = czifile['h'] self.assertEqual( - img.shape, + mask.shape, (h, w, 1, 1), 'Inferred image is not the expected shape' ) self.assertEqual( - img.data[int(w/2), int(h/2)], + mask.data[int(w/2), int(h/2)], 255, 'Middle pixel is not white as expected' ) self.assertEqual( - img.data[0, 0], + mask.data[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + return img, mask + + def test_dummy_instance_segmentation(self): + img, mask = self.test_dummy_pixel_segmentation() + model = DummyInstanceSegmentationModel() + obmap = model.label_instance_class(img, mask) -- GitLab