diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index ac322e855fee685c12b8ceaf175b434a3cf161b4..d3cc0d9354d020c8ff4a9b61536dada9e80eb34a 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -235,11 +235,27 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag if not img.shape == pxmap.shape: raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape') pxch = kwargs.get('pixel_classification_channel', 0) - pxtr = kwargs('pixel_classification_threshold', 0.5) + pxtr = kwargs.get('pixel_classification_threshold', 0.5) mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr) obmap, _ = self.infer(img, mask) return obmap + def make_instance_segmentation_model(self, px_ch: int): + """ + Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a + second input. + :param px_ch: channel of pixel probability map to use + :return: + InstanceSegmentationModel object + """ + class _Mod(self.__class__, InstanceSegmentationModel): + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + return super().label_instance_class(img, mask, pixel_classification_channel=px_ch) + return _Mod(params={'project_file': self.project_file}) + + class Error(Exception): pass diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 21e3d9af93b7447d6b6414eb09507d0ac1e295e7..15bca41de198bfcd5c4fa714140c5fc8f1f331b0 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -136,6 +136,23 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertEqual(objmap.data.max(), 2) + def test_make_seg_obj_model_from_pxmap_obj(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( + {'project_file': ilastik_classifiers['pxmap_to_obj']} + ) + seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0) + objmap = seg_model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 2) + def test_run_object_classifier_from_segmentation(self): self.test_run_pixel_classifier() fp = czifile['path']