Skip to content
Snippets Groups Projects
Commit f5e4fc4e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Check input against model dimensionality and chroma in all ilastik cases

parent 2dfd9037
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!29Dev multichannel classifiers
...@@ -36,7 +36,6 @@ class IlastikModel(Model): ...@@ -36,7 +36,6 @@ class IlastikModel(Model):
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file') raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.shell = None self.shell = None
self.axes = None
super().__init__(autoload, params) super().__init__(autoload, params)
def load(self): def load(self):
...@@ -65,18 +64,10 @@ class IlastikModel(Model): ...@@ -65,18 +64,10 @@ class IlastikModel(Model):
raise ParameterExpectedError( raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
) )
self.axes = [
a['key'].upper() for a in json.loads(h5[f'Input Data/infos/lane0000/Raw Data/axistags'][()])['axes']
]
self.shell = shell self.shell = shell
return True return True
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@property @property
def model_shape_dict(self): def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
...@@ -97,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ...@@ -97,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def model_3d(self): def model_3d(self):
return self.model_shape_dict['Z'] > 1 return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@staticmethod @staticmethod
def get_workflow(): def get_workflow():
from ilastik.workflows import PixelClassificationWorkflow from ilastik.workflows import PixelClassificationWorkflow
...@@ -145,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment ...@@ -145,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
return ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
assert segmentation_img.is_mask() assert segmentation_img.is_mask()
if isinstance(input_img, PatchStack): if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack) assert isinstance(segmentation_img, PatchStack)
...@@ -202,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag ...@@ -202,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment