From b9037651b97dd8ddd58f985db2fdb260fd2e72b4 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 26 Apr 2024 15:05:13 +0200 Subject: [PATCH] Test coverage of ilastik workflow and API endpoint that currently restricts multichannel inputs --- .../extensions/ilastik/tests/test_ilastik.py | 58 ++++++++++++++++--- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index d98a41eb..2406cdbe 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -7,6 +7,7 @@ import numpy as np from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm +from model_server.extensions.ilastik.workflows import infer_px_then_ob_model from model_server.base.models import InvalidObjectLabelsError from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels @@ -325,7 +326,7 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): super(TestIlastikOnMultichannelInputs, self).setUp() self.pa_px_classifier = ilastik_classifiers['px_color_zstack'] self.pa_ob_classifier = ilastik_classifiers['ob_color_zstack'] - self.input_image = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) + self.pa_input_image = roiset_test_data['multichannel_zstack']['path'] def _copy_input_file_to_server(self): from shutil import copyfile @@ -340,21 +341,62 @@ class TestIlastikOnMultichannelInputs(TestServerBaseClass): ) def test_classify_pixels(self): - self.assertGreater(self.input_image.chroma, 1) + img = generate_file_accessor(self.pa_input_image) + self.assertGreater(img.chroma, 1) mod = ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}) - pxmap = mod.infer(self.input_image)[0] - self.assertEqual(pxmap.hw, self.input_image.hw) - self.assertEqual(pxmap.nz, self.input_image.nz) + pxmap = mod.infer(img)[0] + self.assertEqual(pxmap.hw, img.hw) + self.assertEqual(pxmap.nz, img.nz) return pxmap def test_classify_objects(self): pxmap = self.test_classify_pixels() + img = generate_file_accessor(self.pa_input_image) mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}) - obmap = mod.infer(self.input_image, pxmap)[0] - self.assertEqual(obmap.hw, self.input_image.hw) - self.assertEqual(obmap.nz, self.input_image.nz) + obmap = mod.infer(img, pxmap)[0] + self.assertEqual(obmap.hw, img.hw) + self.assertEqual(obmap.nz, img.nz) + + def _call_workflow(self, channel): + return infer_px_then_ob_model( + self.pa_input_image, + ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}), + ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_classifier}), + output_path, + channel=channel, + ) + + def test_workflow(self): + with self.assertRaises(ilm.IlastikInputShapeError): + self._call_workflow(channel=0) + + res = self._call_workflow(channel=None) + def test_api(self): + resp_load = self._put( + 'ilastik/seg/load/', + query={'project_file': str(self.pa_px_classifier)}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + px_model_id = resp_load.json()['model_id'] + + resp_load = self._put( + 'ilastik/pxmap_to_obj/load/', + query={'project_file': str(self.pa_ob_classifier)}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + ob_model_id = resp_load.json()['model_id'] + resp_infer = self._put( + 'ilastik/pixel_then_object_classification/infer/', + query={ + 'px_model_id': px_model_id, + 'ob_model_id': ob_model_id, + 'input_filename': self.pa_input_image, + # 'channel': 0, + } + ) + self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) class TestIlastikObjectClassification(unittest.TestCase): -- GitLab