diff --git a/model_server/conf/testing.py b/model_server/conf/testing.py index 97a13aaf75192c76a75ace8c8bcccf3678370a5d..5fdea52941b420f18849aef78454a3b9a200da45 100644 --- a/model_server/conf/testing.py +++ b/model_server/conf/testing.py @@ -57,6 +57,9 @@ ilastik_classifiers = { 'px': root / 'ilastik' / 'demo_px.ilp', 'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp', 'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp', + 'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp', + 'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp', + 'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp', } roiset_test_data = { diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 2dea13329513d0256baf7e9f2c698d4731301893..2c92c026c08d78052b21b982faa7aa4b4b7e1ec0 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -104,8 +104,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): return [l.decode() for l in h5['PixelClassification/LabelNames'][()]] def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict): - if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): - raise IlastikInputShapeError() + if self.model_chroma != input_img.chroma: + raise IlastikInputShapeError( + f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}' + ) + if self.model_3d != input_img.is_3d(): + if self.model_3d: + raise IlastikInputShapeError(f'Model is 3D but input image is 2D') + else: + raise IlastikInputShapeError(f'Model is 2D but input image is 3D') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ @@ -164,7 +171,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): if self.model_chroma != input_img.chroma: raise IlastikInputShapeError( - f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}' + f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}' ) if self.model_3d != input_img.is_3d(): if self.model_3d: @@ -229,8 +236,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag return ObjectClassificationWorkflowPrediction def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): - if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): - raise IlastikInputShapeError() + if self.model_chroma != input_img.chroma: + raise IlastikInputShapeError( + f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}' + ) + if self.model_3d != input_img.is_3d(): + if self.model_3d: + raise IlastikInputShapeError(f'Model is 3D but input image is 2D') + else: + raise IlastikInputShapeError(f'Model is 2D but input image is 3D') if isinstance(input_img, PatchStack): assert isinstance(pxmap_img, PatchStack) diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index a1a4983efbded47e54804521c04394c55795439c..18e95a4df44279b452bc11226700375b63c15381 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -7,6 +7,7 @@ import numpy as np from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm +from model_server.extensions.ilastik.workflows import infer_px_then_ob_model from model_server.base.models import InvalidObjectLabelsError from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels @@ -320,6 +321,92 @@ class TestIlastikOverApi(TestServerBaseClass): self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) +class TestIlastikOnMultichannelInputs(TestServerBaseClass): + def setUp(self) -> None: + super(TestIlastikOnMultichannelInputs, self).setUp() + self.pa_px_classifier = ilastik_classifiers['px_color_zstack'] + self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack'] + self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack'] + self.pa_input_image = roiset_test_data['multichannel_zstack']['path'] + self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d'] + + def _copy_input_file_to_server(self): + from shutil import copyfile + + pa_data = roiset_test_data['multichannel_zstack']['path'] + resp = self._get('paths') + pa = resp.json()['inbound_images'] + outpath = pathlib.Path(pa) / pa_data.name + copyfile( + czifile['path'], + outpath + ) + + def test_classify_pixels(self): + img = generate_file_accessor(self.pa_input_image) + self.assertGreater(img.chroma, 1) + mod = ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}) + pxmap = mod.infer(img)[0] + self.assertEqual(pxmap.hw, img.hw) + self.assertEqual(pxmap.nz, img.nz) + return pxmap + + def test_classify_objects(self): + pxmap = self.test_classify_pixels() + img = generate_file_accessor(self.pa_input_image) + mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}) + obmap = mod.infer(img, pxmap)[0] + self.assertEqual(obmap.hw, img.hw) + self.assertEqual(obmap.nz, img.nz) + + def _call_workflow(self, channel): + return infer_px_then_ob_model( + self.pa_input_image, + ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}), + ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}), + output_path, + channel=channel, + ) + + def test_workflow(self): + with self.assertRaises(ilm.IlastikInputShapeError): + self._call_workflow(channel=0) + res = self._call_workflow(channel=None) + acc_input = generate_file_accessor(self.pa_input_image) + acc_obmap = generate_file_accessor(res.object_map_filepath) + self.assertEqual(acc_obmap.hw, acc_input.hw) + self.assertEqual(len(acc_obmap._unique()[1]), 3) + + + def test_api(self): + resp_load = self._put( + 'ilastik/seg/load/', + query={'project_file': str(self.pa_px_classifier)}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + px_model_id = resp_load.json()['model_id'] + + resp_load = self._put( + 'ilastik/pxmap_to_obj/load/', + query={'project_file': str(self.pa_ob_pxmap_classifier)}, + ) + self.assertEqual(resp_load.status_code, 200, resp_load.json()) + ob_model_id = resp_load.json()['model_id'] + + resp_infer = self._put( + 'ilastik/pixel_then_object_classification/infer/', + query={ + 'px_model_id': px_model_id, + 'ob_model_id': ob_model_id, + 'input_filename': self.pa_input_image, + } + ) + self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode()) + acc_input = generate_file_accessor(self.pa_input_image) + acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath']) + self.assertEqual(acc_obmap.hw, acc_input.hw) + + class TestIlastikObjectClassification(unittest.TestCase): def setUp(self): stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path']) diff --git a/model_server/extensions/ilastik/workflows.py b/model_server/extensions/ilastik/workflows.py index c5b1575f00c1e27f1fc977ebd709dff48a16681d..6f913f652d013145fdcc2a9f91d3ed0db4234eba 100644 --- a/model_server/extensions/ilastik/workflows.py +++ b/model_server/extensions/ilastik/workflows.py @@ -26,6 +26,7 @@ def infer_px_then_ob_model( px_model: IlastikPixelClassifierModel, ob_model: IlastikObjectClassifierFromPixelPredictionsModel, where_output: Path, + channel: int = None, **kwargs ) -> WorkflowRunRecord: """ @@ -35,6 +36,7 @@ def infer_px_then_ob_model( :param px_model: model instance for pixel classification :param ob_model: model instance for object classification :param where_output: Path object that references output image directory + :param channel: input image channel to pass to pixel classification, or all channels if None :param kwargs: variable-length keyword arguments :return: """ @@ -42,8 +44,12 @@ def infer_px_then_ob_model( assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel) ti = Timer() - ch = kwargs.get('channel') - img = generate_file_accessor(fpi).get_one_channel_data(ch, mip=kwargs.get('mip', False)) + raw_acc = generate_file_accessor(fpi) + if channel is not None: + channels = [channel] + else: + channels = range(0, raw_acc.chroma) + img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False)) ti.click('file_input') px_map, _ = px_model.infer(img)