Skip to content
Snippets Groups Projects
Commit ff6547b2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

ilastik workflow now interprets channel=None to use all input channels

parent 2a542f58
No related branches found
No related tags found
2 merge requests!50Release 2024.06.03,!40ilastik workflow and API allow multichannel inputs
......@@ -57,8 +57,9 @@ ilastik_classifiers = {
'px': root / 'ilastik' / 'demo_px.ilp',
'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
'px_color_zstack': root / 'ilastik' / 'px_color_zstack.ilp',
'ob_color_zstack': root / 'ilastik' / 'ob_color_zstack.ilp',
'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp',
'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp',
'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp',
}
roiset_test_data = {
......
......@@ -325,8 +325,10 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def setUp(self) -> None:
super(TestIlastikOnMultichannelInputs, self).setUp()
self.pa_px_classifier = ilastik_classifiers['px_color_zstack']
self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack']
self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']
self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']
self.pa_input_image = roiset_test_data['multichannel_zstack']['path']
self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d']
def _copy_input_file_to_server(self):
from shutil import copyfile
......@@ -352,7 +354,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def test_classify_objects(self):
pxmap = self.test_classify_pixels()
img = generate_file_accessor(self.pa_input_image)
mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier})
mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier})
obmap = mod.infer(img, pxmap)[0]
self.assertEqual(obmap.hw, img.hw)
self.assertEqual(obmap.nz, img.nz)
......@@ -361,7 +363,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
return infer_px_then_ob_model(
self.pa_input_image,
ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}),
ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}),
ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}),
output_path,
channel=channel,
)
......@@ -369,8 +371,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def test_workflow(self):
with self.assertRaises(ilm.IlastikInputShapeError):
self._call_workflow(channel=0)
res = self._call_workflow(channel=None)
acc_input = generate_file_accessor(self.pa_input_image)
acc_obmap = generate_file_accessor(res.object_map_filepath)
self.assertEqual(acc_obmap.hw, acc_input.hw)
self.assertEqual(len(acc_obmap._unique()[1]), 3)
def test_api(self):
resp_load = self._put(
......@@ -382,7 +388,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
resp_load = self._put(
'ilastik/pxmap_to_obj/load/',
query={'project_file': str(self.pa_ob_classifier)},
query={'project_file': str(self.pa_ob_pxmap_classifier)},
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
ob_model_id = resp_load.json()['model_id']
......@@ -393,10 +399,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
'input_filename': self.pa_input_image,
# 'channel': 0,
}
)
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
acc_input = generate_file_accessor(self.pa_input_image)
acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath'])
self.assertEqual(acc_obmap.hw, acc_input.hw)
class TestIlastikObjectClassification(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment