From f5e4fc4ecad6d5e85432568a7e849608555f8d33 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 9 Apr 2024 16:17:20 +0200 Subject: [PATCH] Check input against model dimensionality and chroma in all ilastik cases --- model_server/extensions/ilastik/models.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 0cb6370c..4126e14e 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -36,7 +36,6 @@ class IlastikModel(Model): raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') self.shell = None - self.axes = None super().__init__(autoload, params) def load(self): @@ -65,18 +64,10 @@ class IlastikModel(Model): raise ParameterExpectedError( f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' ) - self.axes = [ - a['key'].upper() for a in json.loads(h5[f'Input Data/infos/lane0000/Raw Data/axistags'][()])['axes'] - ] self.shell = shell return True - -class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): - model_id = 'ilastik_pixel_classification' - operations = ['segment', ] - @property def model_shape_dict(self): raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] @@ -97,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): def model_3d(self): return self.model_shape_dict['Z'] > 1 + +class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): + model_id = 'ilastik_pixel_classification' + operations = ['segment', ] + @staticmethod def get_workflow(): from ilastik.workflows import PixelClassificationWorkflow @@ -145,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + assert segmentation_img.is_mask() if isinstance(input_img, PatchStack): assert isinstance(segmentation_img, PatchStack) @@ -202,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag return ObjectClassificationWorkflowPrediction def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') -- GitLab