From ffbc837de47d0a8ef575b59f37003513889ac834 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 16 Oct 2023 16:30:04 +0200
Subject: [PATCH] Prototyped object classifier with segmentation

---
 extensions/ilastik/models.py             | 32 ++++++++++++++++++++++++
 extensions/ilastik/tests/test_ilastik.py |  2 ++
 2 files changed, 34 insertions(+)

diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index a6daf81a..fb7ea566 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -106,3 +106,35 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikImageToImageModel)
             [0, 1, 2, 3]
         )
         return InMemoryDataAccessor(data=yxcz), {'success': True}
+
+
+class IlastikObjectClassifierFromSegmentationModel(IlastikImageToImageModel):
+    model_id = 'ilastik_object_classification_from_segmentation'
+
+    @staticmethod
+    def get_workflow():
+        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
+        return ObjectClassificationWorkflowBinary
+
+    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
+        assert segmentation_img.is_mask()
+        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
+        tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
+
+        dsi = [
+            {
+                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
+                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
+            }
+        ]
+
+        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
+
+        assert (len(obmaps) == 1, 'ilastik generated more than one object map')
+
+        yxcz = np.moveaxis(
+            obmaps[0],
+            [1, 2, 3, 0],
+            [0, 1, 2, 3]
+        )
+        return InMemoryDataAccessor(data=yxcz), {'success': True}
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index 348f955e..217b00c9 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -199,3 +199,5 @@ class TestIlastikOverApi(TestServerBaseClass):
             }
         )
         self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
+
+    # TODO: test IlastikObjectClassifierFromSegmentationModel when a test model is complete
\ No newline at end of file
-- 
GitLab