diff --git a/conf/testing.py b/conf/testing.py index 33cc08595f00e9176e0207e793b7d65b8439b9df..9a08092ffaa2eafc1868fb545a398b972d86ba95 100644 --- a/conf/testing.py +++ b/conf/testing.py @@ -56,7 +56,7 @@ monozstackmask = { ilastik_classifiers = { 'px': root / 'ilastik' / 'demo_px.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', - 'seg_to_obj': root / 'ilastik' / 'new_auto_obj.ilp', + 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', } output_path = root / 'testing_output' diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index 864413d814bb6bf8073287ac8be5cb93325729e7..f39a78f87363d290db226cfe9ee811e13af778c8 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -92,9 +92,16 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): - assert segmentation_img.is_mask() tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') - tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + assert segmentation_img.is_mask() + if segmentation_img.dtype == 'bool': + seg = 255 * segmentation_img.data.astype('uint8') + tagged_seg_data = vigra.taggedView( + 255 * segmentation_img.data.astype('uint8'), + 'yxcz' + ) + else: + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') dsi = [ { @@ -105,7 +112,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], @@ -141,7 +148,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] - assert (len(obmaps) == 1, 'ilastik generated more than one object map') + assert len(obmaps) == 1, 'ilastik generated more than one object map' yxcz = np.moveaxis( obmaps[0], diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index ed74c43e8293a743876803315f3adfd964dfaa6b..1dfbde1685218234f152ace5968a1aae4a8a9fff 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -82,7 +82,7 @@ class TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask - def test_run_object_classifier(self): + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path'] model = ilm.IlastikObjectClassifierFromPixelPredictionsModel( @@ -98,6 +98,22 @@ class TestIlastikPixelClassification(unittest.TestCase): ) self.assertEqual(objmap.data.max(), 3) + def test_run_object_classifier_from_segmentation(self): + self.test_run_pixel_classifier() + fp = czifile['path'] + model = ilm.IlastikObjectClassifierFromSegmentationModel( + {'project_file': ilastik_classifiers['seg_to_obj']} + ) + objmap = model.label_instance_class(self.mono_image, self.mask) + + self.assertTrue( + write_accessor_data_to_file( + output_path / f'obmap_from_seg_{fp.stem}.tif', + objmap, + ) + ) + self.assertEqual(objmap.data.max(), 3) + def test_ilastik_pixel_classification_as_workflow(self): result = infer_image_to_image( czifile['path'], diff --git a/model_server/workflows.py b/model_server/workflows.py index c81d2afdc3e59e8153e1710524e286961ecd7746..6eb7b50d1b6e78518cc500ac5c1936cd266a333b 100644 --- a/model_server/workflows.py +++ b/model_server/workflows.py @@ -43,7 +43,7 @@ def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs) img = generate_file_accessor(fpi).get_one_channel_data(ch) ti.click('file_input') - outdata, _ = model.infer(img) + outdata = model.label_pixel_class(img) ti.click('inference') outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif')