From ff6547b22098cee4e5335ccb56db485fd2339f71 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 26 Apr 2024 15:49:00 +0200 Subject: [PATCH] ilastik workflow now interprets channel=None to use all input channels --- model_server/conf/testing.py | 5 +++-- .../extensions/ilastik/tests/test_ilastik.py | 20 +++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 72b8816d..5fdea529 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -57,8 +57,9 @@ ilastik_classifiers = { 'px': root / 'ilastik' / 'demo_px.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', - 'px_color_zstack': root / 'ilastik' / 'px_color_zstack.ilp', - 'ob_color_zstack': root / 'ilastik' / 'ob_color_zstack.ilp', + 'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp', + 'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp', + 'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp', } roiset_test_data = { diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index 2406cdbe..18e95a4d 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -325,8 +325,10 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): def setUp(self) -> None: super(TestIlastikOnMultichannelInputs, self).setUp() self.pa_px_classifier = ilastik_classifiers['px_color_zstack'] - self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack'] + self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack'] + self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack'] self.pa_input_image = roiset_test_data['multichannel_zstack']['path'] + self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d'] def _copy_input_file_to_server(self): from shutil import copyfile @@ -352,7 +354,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): def test_classify_objects(self): pxmap = self.test_classify_pixels() img = generate_file_accessor(self.pa_input_image) - mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}) + mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}) obmap = mod.infer(img, pxmap)[0] self.assertEqual(obmap.hw, img.hw) self.assertEqual(obmap.nz, img.nz) @@ -361,7 +363,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): return infer_px_then_ob_model( self.pa_input_image, ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}), - ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}), + ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}), output_path, channel=channel, ) @@ -369,8 +371,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): def test_workflow(self): with self.assertRaises(ilm.IlastikInputShapeError): self._call_workflow(channel=0) - res = self._call_workflow(channel=None) + acc_input = generate_file_accessor(self.pa_input_image) + acc_obmap = generate_file_accessor(res.object_map_filepath) + self.assertEqual(acc_obmap.hw, acc_input.hw) + self.assertEqual(len(acc_obmap._unique()[1]), 3) + def test_api(self): resp_load = self._put( @@ -382,7 +388,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): resp_load = self._put( 'ilastik/pxmap_to_obj/load/', - query={'project_file': str(self.pa_ob_classifier)}, + query={'project_file': str(self.pa_ob_pxmap_classifier)}, ) self.assertEqual(resp_load.status_code, 200, resp_load.json()) ob_model_id = resp_load.json()['model_id'] @@ -393,10 +399,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): 'px_model_id': px_model_id, 'ob_model_id': ob_model_id, 'input_filename': self.pa_input_image, - # 'channel': 0, } ) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + acc_input = generate_file_accessor(self.pa_input_image) + acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath']) + self.assertEqual(acc_obmap.hw, acc_input.hw) class TestIlastikObjectClassification(unittest.TestCase): -- GitLab