diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index d98a41eb203a3646ae4355e5af249268bb883ec9..2406cdbe1e5f499139847bba3c42b3fac8f1dada 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -7,6 +7,7 @@ import numpy as np from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm +from model_server.extensions.ilastik.workflows import infer_px_then_ob_model from model_server.base.models import InvalidObjectLabelsError from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels @@ -325,7 +326,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): super(TestIlastikOnMultichannelInputs, self).setUp() self.pa_px_classifier = ilastik_classifiers['px_color_zstack'] self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack'] - self.input_image = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + self.pa_input_image = roiset_test_data['multichannel_zstack']['path'] def _copy_input_file_to_server(self): from shutil import copyfile @@ -340,21 +341,62 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ) def test_classify_pixels(self): - self.assertGreater(self.input_image.chroma, 1) + img = generate_file_accessor(self.pa_input_image) + self.assertGreater(img.chroma, 1) mod = ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}) - pxmap = mod.infer(self.input_image)[0] - self.assertEqual(pxmap.hw, self.input_image.hw) - self.assertEqual(pxmap.nz, self.input_image.nz) + pxmap = mod.infer(img)[0] + self.assertEqual(pxmap.hw, img.hw) + self.assertEqual(pxmap.nz, img.nz) return pxmap def test_classify_objects(self): pxmap = self.test_classify_pixels() + img = generate_file_accessor(self.pa_input_image) mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}) - obmap = mod.infer(self.input_image, pxmap)[0] - self.assertEqual(obmap.hw, self.input_image.hw) - self.assertEqual(obmap.nz, self.input_image.nz) + obmap = mod.infer(img, pxmap)[0] + self.assertEqual(obmap.hw, img.hw) + self.assertEqual(obmap.nz, img.nz) + + def _call_workflow(self, channel): + return infer_px_then_ob_model( + self.pa_input_image, + ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}), + ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}), + output_path, + channel=channel, + ) + + def test_workflow(self): + with self.assertRaises(ilm.IlastikInputShapeError): + self._call_workflow(channel=0) + + res = self._call_workflow(channel=None) + def test_api(self): + resp_load = self._put( + 'ilastik/seg/load/', + query={'project_file': str(self.pa_px_classifier)}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + px_model_id = resp_load.json()['model_id'] + + resp_load = self._put( + 'ilastik/pxmap_to_obj/load/', + query={'project_file': str(self.pa_ob_classifier)}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + ob_model_id = resp_load.json()['model_id'] + resp_infer = self._put( + 'ilastik/pixel_then_object_classification/infer/', + query={ + 'px_model_id': px_model_id, + 'ob_model_id': ob_model_id, + 'input_filename': self.pa_input_image, + # 'channel': 0, + } + ) + self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) class TestIlastikObjectClassification(unittest.TestCase):