From a372c4643622d49f91e483f3b3d56d7122994e94 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 13:27:49 +0100 Subject: [PATCH] Test and example notebook that covers object classification from binary segmentation --- conf/testing.py | 2 +- extensions/ilastik/models.py | 15 +++++++++++---- extensions/ilastik/tests/test_ilastik.py | 18 +++++++++++++++++- model_server/workflows.py | 2 +- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/conf/testing.py b/conf/testing.py index 33cc0859..9a08092f 100644 --- a/conf/testing.py +++ b/conf/testing.py @@ -56,7 +56,7 @@ monozstackmask = { ilastik_classifiers = { 'px': root / 'ilastik' / 'demo_px.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', - 'seg_to_obj': root / 'ilastik' / 'new_auto_obj.ilp', + 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', } output_path = root / 'testing_output' diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 864413d8..f39a78f8 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -92,9 +92,16 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - assert segmentation_img.is_mask() tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + assert segmentation_img.is_mask() + if segmentation_img.dtype == 'bool': + seg = 255 * segmentation_img.data.astype('uint8') + tagged_seg_data = vigra.taggedView( + 255 * segmentation_img.data.astype('uint8'), + 'yxcz' + ) + else: + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') dsi = [ { @@ -105,7 +112,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], @@ -141,7 +148,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index ed74c43e..1dfbde16 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -82,7 +82,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask - def test_run_object_classifier(self): + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( @@ -98,6 +98,22 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertEqual(objmap.data.max(), 3) + def test_run_object_classifier_from_segmentation(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + model = ilm.IlastikObjectClassifierFromSegmentationModel( + {'project_file': ilastik_classifiers['seg_to_obj']} + ) + objmap = model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_from_seg_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 3) + def test_ilastik_pixel_classification_as_workflow(self): result = infer_image_to_image( czifile['path'], diff --git a/model_server/workflows.py b/model_server/workflows.py index c81d2afd..6eb7b50d 100644 --- a/model_server/workflows.py +++ b/model_server/workflows.py @@ -43,7 +43,7 @@ def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) img = generate_file_accessor(fpi).get_one_channel_data(ch) ti.click('file_input') - outdata, _ = model.infer(img) + outdata = model.label_pixel_class(img) ti.click('inference') outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif') -- GitLab