diff --git a/model_server/base/models.py b/model_server/base/models.py index 4fa0849031bcd318a66f49c8843763b0a99c2279..a8576d40c2a1be0d1f2b44ac65eb0a8354b5243d 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel): """ if not mask.is_mask(): raise InvalidInputImageError('Expecting a binary mask') - if not img.shape == mask.shape: + if img.hw != mask.hw or img.nz != mask.nz: raise InvalidInputImageError('Expect input image and mask to be the same shape') def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs): diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 7229b43eb601dc0703499392403e42e1bce3bf98..4126e14e11e2591fe00c9e21b0430eda0eb7085f 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -53,14 +53,13 @@ class IlastikModel(Model): shell = app.main(args, init_logging=False) # validate if inputs are embedded in project file - input_groups = shell.projectManager.currentProjectFile['Input Data']['infos'] - lanes = input_groups.keys() - for ll in lanes: - input_types = input_groups[ll] - for tt in input_types: - ds_loc = input_groups[ll][tt].get('location', False) - if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem': + h5 = shell.projectManager.currentProjectFile + for lane in h5['Input Data/infos'].keys(): + for role in h5[f'Input Data/infos/{lane}'].keys(): + grp = h5[f'Input Data/infos/{lane}/{role}'] + if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal': raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem') + assert True if not isinstance(shell.workflow, self.get_workflow()): raise ParameterExpectedError( f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' @@ -69,11 +68,6 @@ class IlastikModel(Model): return True - -class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): - model_id = 'ilastik_pixel_classification' - operations = ['segment', ] - @property def model_shape_dict(self): raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] @@ -94,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): def model_3d(self): return self.model_shape_dict['Z'] > 1 + +class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): + model_id = 'ilastik_pixel_classification' + operations = ['segment', ] + @staticmethod def get_workflow(): from ilastik.workflows import PixelClassificationWorkflow @@ -142,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment return ObjectClassificationWorkflowBinary def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + assert segmentation_img.is_mask() if isinstance(input_img, PatchStack): assert isinstance(segmentation_img, PatchStack) @@ -199,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag return ObjectClassificationWorkflowPrediction def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')