From 2a542f58588b437664720eb2c65946b69820e858 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 26 Apr 2024 15:16:33 +0200 Subject: [PATCH] ilastik now allows multichannel inputs --- model_server/extensions/ilastik/models.py | 24 ++++++++++++++++---- model_server/extensions/ilastik/workflows.py | 10 ++++++-- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 2dea1332..2c92c026 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -104,8 +104,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): return [l.decode() for l in h5['PixelClassification/LabelNames'][()]] def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict): - if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): - raise IlastikInputShapeError() + if self.model_chroma != input_img.chroma: + raise IlastikInputShapeError( + f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}' + ) + if self.model_3d != input_img.is_3d(): + if self.model_3d: + raise IlastikInputShapeError(f'Model is 3D but input image is 2D') + else: + raise IlastikInputShapeError(f'Model is 2D but input image is 3D') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ @@ -164,7 +171,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): if self.model_chroma != input_img.chroma: raise IlastikInputShapeError( - f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}' + f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}' ) if self.model_3d != input_img.is_3d(): if self.model_3d: @@ -229,8 +236,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag return ObjectClassificationWorkflowPrediction def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): - if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): - raise IlastikInputShapeError() + if self.model_chroma != input_img.chroma: + raise IlastikInputShapeError( + f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}' + ) + if self.model_3d != input_img.is_3d(): + if self.model_3d: + raise IlastikInputShapeError(f'Model is 3D but input image is 2D') + else: + raise IlastikInputShapeError(f'Model is 2D but input image is 3D') if isinstance(input_img, PatchStack): assert isinstance(pxmap_img, PatchStack) diff --git a/model_server/extensions/ilastik/workflows.py b/model_server/extensions/ilastik/workflows.py index c5b1575f..6f913f65 100644 --- a/model_server/extensions/ilastik/workflows.py +++ b/model_server/extensions/ilastik/workflows.py @@ -26,6 +26,7 @@ def infer_px_then_ob_model( px_model: IlastikPixelClassifierModel, ob_model: IlastikObjectClassifierFromPixelPredictionsModel, where_output: Path, + channel: int = None, **kwargs ) -> WorkflowRunRecord: """ @@ -35,6 +36,7 @@ def infer_px_then_ob_model( :param px_model: model instance for pixel classification :param ob_model: model instance for object classification :param where_output: Path object that references output image directory + :param channel: input image channel to pass to pixel classification, or all channels if None :param kwargs: variable-length keyword arguments :return: """ @@ -42,8 +44,12 @@ def infer_px_then_ob_model( assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel) ti = Timer() - ch = kwargs.get('channel') - img = generate_file_accessor(fpi).get_one_channel_data(ch, mip=kwargs.get('mip', False)) + raw_acc = generate_file_accessor(fpi) + if channel is not None: + channels = [channel] + else: + channels = range(0, raw_acc.chroma) + img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False)) ti.click('file_input') px_map, _ = px_model.infer(img) -- GitLab