From ff6547b22098cee4e5335ccb56db485fd2339f71 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 26 Apr 2024 15:49:00 +0200
Subject: [PATCH] ilastik workflow now interprets channel=None to use all input
 channels

---
 model_server/conf/testing.py                  |  5 +++--
 .../extensions/ilastik/tests/test_ilastik.py  | 20 +++++++++++++------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py
index 72b8816d..5fdea529 100644
--- a/model_server/conf/testing.py
+++ b/model_server/conf/testing.py
@@ -57,8 +57,9 @@ ilastik_classifiers = {
     'px': root / 'ilastik' / 'demo_px.ilp',
     'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
     'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
-    'px_color_zstack': root / 'ilastik' / 'px_color_zstack.ilp',
-    'ob_color_zstack': root / 'ilastik' / 'ob_color_zstack.ilp',
+    'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp',
+    'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp',
+    'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp',
 }
 
 roiset_test_data = {
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index 2406cdbe..18e95a4d 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -325,8 +325,10 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
     def setUp(self) -> None:
         super(TestIlastikOnMultichannelInputs, self).setUp()
         self.pa_px_classifier = ilastik_classifiers['px_color_zstack']
-        self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack']
+        self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']
+        self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']
         self.pa_input_image = roiset_test_data['multichannel_zstack']['path']
+        self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d']
 
     def _copy_input_file_to_server(self):
         from shutil import copyfile
@@ -352,7 +354,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
     def test_classify_objects(self):
         pxmap = self.test_classify_pixels()
         img = generate_file_accessor(self.pa_input_image)
-        mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier})
+        mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier})
         obmap = mod.infer(img, pxmap)[0]
         self.assertEqual(obmap.hw, img.hw)
         self.assertEqual(obmap.nz, img.nz)
@@ -361,7 +363,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
         return infer_px_then_ob_model(
             self.pa_input_image,
             ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}),
-            ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}),
+            ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}),
             output_path,
             channel=channel,
         )
@@ -369,8 +371,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
     def test_workflow(self):
         with self.assertRaises(ilm.IlastikInputShapeError):
             self._call_workflow(channel=0)
-
         res = self._call_workflow(channel=None)
+        acc_input = generate_file_accessor(self.pa_input_image)
+        acc_obmap = generate_file_accessor(res.object_map_filepath)
+        self.assertEqual(acc_obmap.hw, acc_input.hw)
+        self.assertEqual(len(acc_obmap._unique()[1]), 3)
+
 
     def test_api(self):
         resp_load = self._put(
@@ -382,7 +388,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
 
         resp_load = self._put(
             'ilastik/pxmap_to_obj/load/',
-            query={'project_file': str(self.pa_ob_classifier)},
+            query={'project_file': str(self.pa_ob_pxmap_classifier)},
         )
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
         ob_model_id = resp_load.json()['model_id']
@@ -393,10 +399,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
                 'px_model_id': px_model_id,
                 'ob_model_id': ob_model_id,
                 'input_filename': self.pa_input_image,
-                # 'channel': 0,
             }
         )
         self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
+        acc_input = generate_file_accessor(self.pa_input_image)
+        acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath'])
+        self.assertEqual(acc_obmap.hw, acc_input.hw)
 
 
 class TestIlastikObjectClassification(unittest.TestCase):
-- 
GitLab