From bbc607ccfe0234f738322c87ea4d40fa78c08049 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 11:30:05 +0100 Subject: [PATCH] Updated ilastik tests to cover semantic and instance segmentation methods --- extensions/ilastik/models.py | 6 +++--- extensions/ilastik/tests/test_ilastik.py | 21 ++++++++++---------- model_server/models.py | 25 ++++++++++++++++++++---- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index f3350ab7..9b55d6d0 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -77,7 +77,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ) return InMemoryDataAccessor(data=yxcz), {'success': True} - def segment(self, img: GenericImageDataAccessor, pixel_class: int, pixel_probability_threshold=0.5): + def label_pixel_class(self, img: GenericImageDataAccessor, pixel_class: int = 0, pixel_probability_threshold=0.5): pxmap, _ = self.infer(img) mask = pxmap.data[:, :, pixel_class, :] > pixel_probability_threshold return InMemoryDataAccessor(mask) @@ -145,7 +145,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment ) return InMemoryDataAccessor(data=yxcz), {'success': True} - def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): - super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_classes(img, mask, **kwargs) + def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs) obmap, _ = self.infer(img, mask) return obmap \ No newline at end of file diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index f4f400a6..ed74c43e 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -35,7 +35,7 @@ class TestIlastikPixelClassification(unittest.TestCase): input_img = InMemoryDataAccessor(data=np.random.rand(w, h, 1, 1)) with self.assertRaises(AttributeError): - pxmap, _ = model.infer(input_img) + mask = model.label_pixel_class(input_img) def test_run_pixel_classifier_on_random_data(self): @@ -47,8 +47,8 @@ class TestIlastikPixelClassification(unittest.TestCase): input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1)) - pxmap, _ = model.infer(input_img) - self.assertEqual(pxmap.shape, (h, w, 2, 1)) + mask = model.label_pixel_class(input_img) + self.assertEqual(mask.shape, (h, w, 1, 1)) def test_run_pixel_classifier(self): @@ -66,20 +66,21 @@ class TestIlastikPixelClassification(unittest.TestCase): self.assertEqual(mono_image.shape_dict['C'], 1) self.assertEqual(mono_image.shape_dict['Z'], 1) - pxmap, _ = model.infer(mono_image) + mask = model.label_pixel_class(mono_image) - self.assertEqual(pxmap.shape[0:2], cf.shape[0:2]) - self.assertEqual(pxmap.shape_dict['C'], 2) - self.assertEqual(pxmap.shape_dict['Z'], 1) + self.assertTrue(mask.is_mask()) + self.assertEqual(mask.shape[0:2], cf.shape[0:2]) + self.assertEqual(mask.shape_dict['C'], 1) + self.assertEqual(mask.shape_dict['Z'], 1) self.assertTrue( write_accessor_data_to_file( output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif', - pxmap + mask ) ) self.mono_image = mono_image - self.pxmap = pxmap + self.mask = mask def test_run_object_classifier(self): self.test_run_pixel_classifier() @@ -87,7 +88,7 @@ class TestIlastikPixelClassification(unittest.TestCase): model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( {'project_file': ilastik_classifiers['pxmap_to_obj']} ) - objmap, _ = model.infer(self.mono_image, self.pxmap) + objmap, _ = model.infer(self.mono_image, self.mask) self.assertTrue( write_accessor_data_to_file( diff --git a/model_server/models.py b/model_server/models.py index 087a9975..4f07da57 100644 --- a/model_server/models.py +++ b/model_server/models.py @@ -36,7 +36,12 @@ class Model(ABC): pass @abstractmethod - def infer(self, *args) -> (object, dict): # return json describing inference result + def infer(self, *args) -> (object, dict): + """ + Abstract method that carries out the computationally intensive step of running data through a model + :param args: + :return: + """ pass def reload(self): @@ -55,18 +60,30 @@ class ImageToImageModel(Model): class SemanticSegmentationModel(ImageToImageModel): """ - Model that exposes a method that returns a binary mask for a given input image and pixel class + Base model that exposes a method that returns a binary mask for a given input image and pixel class """ @abstractmethod - def segment(self, img: GenericImageDataAccessor, pixel_class: int, **kwargs) -> (GenericImageDataAccessor, dict): + def label_pixel_class( + self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor: + """ + Given an image, return an image of the same shape where each pixel is assigned to one or more integer classes + """ pass class InstanceSegmentationModel(ImageToImageModel): + """ + Base model that exposes a method that returns an instance classification map for a given input image and mask + """ @abstractmethod - def label_instance_classes(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs): + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + """ + Given an image and a mask of the same size, return a map where each connected object is assigned a class + """ if not mask.is_mask(): raise InvalidInputImageError('Expecting a binary mask') if not img.shape == mask.shape: -- GitLab