From c3a8f9153b3b50ad5c556c32b55643fd2e79595f Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 25 Mar 2024 12:44:46 +0100 Subject: [PATCH] Generate an InstanceSegmentationModel from an IlastikObjectClassifierFromPixelPredictionsModel given binarization parameters --- model_server/extensions/ilastik/models.py | 18 +++++++++++++++++- .../extensions/ilastik/tests/test_ilastik.py | 17 +++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index ac322e85..d3cc0d93 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -235,11 +235,27 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') pxch = kwargs.get('pixel_classification_channel', 0) - pxtr = kwargs('pixel_classification_threshold', 0.5) + pxtr = kwargs.get('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) obmap, _ = self.infer(img, mask) return obmap + def make_instance_segmentation_model(self, px_ch: int): + """ + Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a + second input. + :param px_ch: channel of pixel probability map to use + :return: + InstanceSegmentationModel object + """ + class _Mod(self.__class__, InstanceSegmentationModel): + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + return super().label_instance_class(img, mask, pixel_classification_channel=px_ch) + return _Mod(params={'project_file': self.project_file}) + + class Error(Exception): pass diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 21e3d9af..15bca41d 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -136,6 +136,23 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertEqual(objmap.data.max(), 2) + def test_make_seg_obj_model_from_pxmap_obj(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( + {'project_file': ilastik_classifiers['pxmap_to_obj']} + ) + seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0) + objmap = seg_model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 2) + def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() fp = czifile['path'] -- GitLab