From ffbc837de47d0a8ef575b59f37003513889ac834 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Mon, 16 Oct 2023 16:30:04 +0200 Subject: [PATCH] Prototyped object classifier with segmentation --- extensions/ilastik/models.py | 32 ++++++++++++++++++++++++ extensions/ilastik/tests/test_ilastik.py | 2 ++ 2 files changed, 34 insertions(+) diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py index a6daf81a..fb7ea566 100644 --- a/extensions/ilastik/models.py +++ b/extensions/ilastik/models.py @@ -106,3 +106,35 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel) [0, 1, 2, 3] ) return InMemoryDataAccessor(data=yxcz), {'success': True} + + +class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel): + model_id = 'ilastik_object_classification_from_segmentation' + + @staticmethod + def get_workflow(): + from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary + return ObjectClassificationWorkflowBinary + + def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): + assert segmentation_img.is_mask() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') + tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz') + + dsi = [ + { + 'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data), + 'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data), + } + ] + + obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n] + + assert (len(obmaps) == 1, 'ilastik generated more than one object map') + + yxcz = np.moveaxis( + obmaps[0], + [1, 2, 3, 0], + [0, 1, 2, 3] + ) + return InMemoryDataAccessor(data=yxcz), {'success': True} diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py index 348f955e..217b00c9 100644 --- a/extensions/ilastik/tests/test_ilastik.py +++ b/extensions/ilastik/tests/test_ilastik.py @@ -199,3 +199,5 @@ class TestIlastikOverApi(TestServerBaseClass): } ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + + # TODO: test IlastikObjectClassifierFromSegmentationModel when a test model is complete \ No newline at end of file -- GitLab