Skip to content
Snippets Groups Projects
Commit ff6547b2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

ilastik workflow now interprets channel=None to use all input channels

parent 2a542f58
No related branches found
No related tags found
2 merge requests!50Release 2024.06.03,!40ilastik workflow and API allow multichannel inputs
...@@ -57,8 +57,9 @@ ilastik_classifiers = { ...@@ -57,8 +57,9 @@ ilastik_classifiers = {
'px': root / 'ilastik' / 'demo_px.ilp', 'px': root / 'ilastik' / 'demo_px.ilp',
'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
'px_color_zstack': root / 'ilastik' / 'px_color_zstack.ilp', 'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp',
'ob_color_zstack': root / 'ilastik' / 'ob_color_zstack.ilp', 'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp',
'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp',
} }
roiset_test_data = { roiset_test_data = {
......
...@@ -325,8 +325,10 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ...@@ -325,8 +325,10 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def setUp(self) -> None: def setUp(self) -> None:
super(TestIlastikOnMultichannelInputs, self).setUp() super(TestIlastikOnMultichannelInputs, self).setUp()
self.pa_px_classifier = ilastik_classifiers['px_color_zstack'] self.pa_px_classifier = ilastik_classifiers['px_color_zstack']
self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack'] self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']
self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']
self.pa_input_image = roiset_test_data['multichannel_zstack']['path'] self.pa_input_image = roiset_test_data['multichannel_zstack']['path']
self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d']
def _copy_input_file_to_server(self): def _copy_input_file_to_server(self):
from shutil import copyfile from shutil import copyfile
...@@ -352,7 +354,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ...@@ -352,7 +354,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def test_classify_objects(self): def test_classify_objects(self):
pxmap = self.test_classify_pixels() pxmap = self.test_classify_pixels()
img = generate_file_accessor(self.pa_input_image) img = generate_file_accessor(self.pa_input_image)
mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}) mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier})
obmap = mod.infer(img, pxmap)[0] obmap = mod.infer(img, pxmap)[0]
self.assertEqual(obmap.hw, img.hw) self.assertEqual(obmap.hw, img.hw)
self.assertEqual(obmap.nz, img.nz) self.assertEqual(obmap.nz, img.nz)
...@@ -361,7 +363,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ...@@ -361,7 +363,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
return infer_px_then_ob_model( return infer_px_then_ob_model(
self.pa_input_image, self.pa_input_image,
ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}), ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}),
ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}), ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}),
output_path, output_path,
channel=channel, channel=channel,
) )
...@@ -369,8 +371,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ...@@ -369,8 +371,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def test_workflow(self): def test_workflow(self):
with self.assertRaises(ilm.IlastikInputShapeError): with self.assertRaises(ilm.IlastikInputShapeError):
self._call_workflow(channel=0) self._call_workflow(channel=0)
res = self._call_workflow(channel=None) res = self._call_workflow(channel=None)
acc_input = generate_file_accessor(self.pa_input_image)
acc_obmap = generate_file_accessor(res.object_map_filepath)
self.assertEqual(acc_obmap.hw, acc_input.hw)
self.assertEqual(len(acc_obmap._unique()[1]), 3)
def test_api(self): def test_api(self):
resp_load = self._put( resp_load = self._put(
...@@ -382,7 +388,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ...@@ -382,7 +388,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
resp_load = self._put( resp_load = self._put(
'ilastik/pxmap_to_obj/load/', 'ilastik/pxmap_to_obj/load/',
query={'project_file': str(self.pa_ob_classifier)}, query={'project_file': str(self.pa_ob_pxmap_classifier)},
) )
self.assertEqual(resp_load.status_code, 200, resp_load.json()) self.assertEqual(resp_load.status_code, 200, resp_load.json())
ob_model_id = resp_load.json()['model_id'] ob_model_id = resp_load.json()['model_id']
...@@ -393,10 +399,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ...@@ -393,10 +399,12 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass):
'px_model_id': px_model_id, 'px_model_id': px_model_id,
'ob_model_id': ob_model_id, 'ob_model_id': ob_model_id,
'input_filename': self.pa_input_image, 'input_filename': self.pa_input_image,
# 'channel': 0,
} }
) )
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
acc_input = generate_file_accessor(self.pa_input_image)
acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath'])
self.assertEqual(acc_obmap.hw, acc_input.hw)
class TestIlastikObjectClassification(unittest.TestCase): class TestIlastikObjectClassification(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment