From a372c4643622d49f91e483f3b3d56d7122994e94 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 20 Dec 2023 13:27:49 +0100
Subject: [PATCH] Add test and example notebook covering object classification
 from binary segmentation
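
The object classifier can now be driven directly from a binary segmentation:
IlastikObjectClassifierFromSegmentationModel.infer() scales boolean masks to
uint8 (0/255) before building the vigra tagged view, the always-true
(parenthesized) asserts on the number of exported object maps are corrected,
the testing config points at demo_obj_seg.ilp, and infer_image_to_image() now
calls label_pixel_class() directly.

As a rough usage sketch of the new path, mirroring the test added in
extensions/ilastik/tests/test_ilastik.py. Import paths and the origin of the
intensity and mask accessors are assumptions for illustration, not part of
this patch:

    # Assumed imports mirroring conf/testing.py and the test module; the
    # writer helper write_accessor_data_to_file comes from the project's
    # accessor utilities (its import path is not shown in this patch).
    from conf.testing import czifile, ilastik_classifiers, output_path
    import extensions.ilastik.models as ilm

    model = ilm.IlastikObjectClassifierFromSegmentationModel(
        {'project_file': ilastik_classifiers['seg_to_obj']}
    )

    # mono_image: single-channel intensity accessor; mask: boolean or 8-bit
    # segmentation accessor (e.g. the output of the pixel classification step)
    objmap = model.label_instance_class(mono_image, mask)

    fp = czifile['path']
    write_accessor_data_to_file(
        output_path / f'obmap_from_seg_{fp.stem}.tif',
        objmap,
    )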

---
 conf/testing.py                          |  2 +-
 extensions/ilastik/models.py             | 15 +++++++++++----
 extensions/ilastik/tests/test_ilastik.py | 18 +++++++++++++++++-
 model_server/workflows.py                |  2 +-
 4 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/conf/testing.py b/conf/testing.py
index 33cc0859..9a08092f 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -56,7 +56,7 @@ monozstackmask = {
 ilastik_classifiers = {
     'px': root / 'ilastik' / 'demo_px.ilp',
     'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
-    'seg_to_obj': root / 'ilastik' / 'new_auto_obj.ilp',
+    'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
 }
 
 output_path = root / 'testing_output'
diff --git a/extensions/ilastik/models.py b/extensions/ilastik/models.py
index 864413d8..f39a78f8 100644
--- a/extensions/ilastik/models.py
+++ b/extensions/ilastik/models.py
@@ -92,9 +92,16 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
         return ObjectClassificationWorkflowBinary
 
     def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
-        assert segmentation_img.is_mask()
         tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
-        tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
+        assert segmentation_img.is_mask()
+        if segmentation_img.dtype == 'bool':
+            # boolean masks are converted to uint8 with True scaled to 255
+            # before being wrapped in a vigra tagged view; masks that already
+            # carry integer data are passed through unchanged
+            seg = 255 * segmentation_img.data.astype('uint8')
+            tagged_seg_data = vigra.taggedView(seg, 'yxcz')
+        else:
+            tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
 
         dsi = [
             {
@@ -105,7 +112,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
 
         obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
 
-        assert (len(obmaps) == 1, 'ilastik generated more than one object map')
+        assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
         yxcz = np.moveaxis(
             obmaps[0],
@@ -141,7 +148,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
 
         obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
 
-        assert (len(obmaps) == 1, 'ilastik generated more than one object map')
+        assert len(obmaps) == 1, 'ilastik generated more than one object map'
 
         yxcz = np.moveaxis(
             obmaps[0],
diff --git a/extensions/ilastik/tests/test_ilastik.py b/extensions/ilastik/tests/test_ilastik.py
index ed74c43e..1dfbde16 100644
--- a/extensions/ilastik/tests/test_ilastik.py
+++ b/extensions/ilastik/tests/test_ilastik.py
@@ -82,7 +82,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.mono_image = mono_image
         self.mask = mask
 
-    def test_run_object_classifier(self):
+    def test_run_object_classifier_from_pixel_predictions(self):
         self.test_run_pixel_classifier()
         fp = czifile['path']
         model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
@@ -98,6 +98,22 @@ class TestIlastikPixelClassification(unittest.TestCase):
         )
         self.assertEqual(objmap.data.max(), 3)
 
+    def test_run_object_classifier_from_segmentation(self):
+        self.test_run_pixel_classifier()
+        fp = czifile['path']
+        model = ilm.IlastikObjectClassifierFromSegmentationModel(
+            {'project_file': ilastik_classifiers['seg_to_obj']}
+        )
+        objmap = model.label_instance_class(self.mono_image, self.mask)
+
+        self.assertTrue(
+            write_accessor_data_to_file(
+                output_path / f'obmap_from_seg_{fp.stem}.tif',
+                objmap,
+            )
+        )
+        self.assertEqual(objmap.data.max(), 3)
+
     def test_ilastik_pixel_classification_as_workflow(self):
         result = infer_image_to_image(
             czifile['path'],
diff --git a/model_server/workflows.py b/model_server/workflows.py
index c81d2afd..6eb7b50d 100644
--- a/model_server/workflows.py
+++ b/model_server/workflows.py
@@ -43,7 +43,7 @@ def infer_image_to_image(fpi: Path, model: Model, where_output: Path, **kwargs)
     img = generate_file_accessor(fpi).get_one_channel_data(ch)
     ti.click('file_input')
 
-    outdata, _ = model.infer(img)
+    outdata = model.label_pixel_class(img)
     ti.click('inference')
 
     outpath = where_output / (model.model_id + '_' + fpi.stem + '.tif')
-- 
GitLab