Skip to content
Snippets Groups Projects
Commit c3a8f915 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Generate an InstanceSegmentationModel from an...

Generate an InstanceSegmentationModel from an IlastikObjectClassifierFromPixelPredictionsModel given binarization parameters
parent 058633cb
No related branches found
No related tags found
2 merge requests!16Completed (de)serialization of RoiSet,!11Generate an InstanceSegmentationModel from an...
......@@ -235,11 +235,27 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
pxtr = kwargs.get('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
obmap, _ = self.infer(img, mask)
return obmap
def make_instance_segmentation_model(self, px_ch: int):
"""
Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a
second input.
:param px_ch: channel of pixel probability map to use
:return:
InstanceSegmentationModel object
"""
class _Mod(self.__class__, InstanceSegmentationModel):
def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor:
return super().label_instance_class(img, mask, pixel_classification_channel=px_ch)
return _Mod(params={'project_file': self.project_file})
class Error(Exception):
pass
......
......@@ -136,6 +136,23 @@ class TestIlastikPixelClassification(unittest.TestCase):
)
self.assertEqual(objmap.data.max(), 2)
def test_make_seg_obj_model_from_pxmap_obj(self):
self.test_run_pixel_classifier()
fp = czifile['path']
pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
{'project_file': ilastik_classifiers['pxmap_to_obj']}
)
seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0)
objmap = seg_model.label_instance_class(self.mono_image, self.mask)
self.assertTrue(
write_accessor_data_to_file(
output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif',
objmap,
)
)
self.assertEqual(objmap.data.max(), 2)
def test_run_object_classifier_from_segmentation(self):
self.test_run_pixel_classifier()
fp = czifile['path']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment