Skip to content
Snippets Groups Projects
Commit f5e4fc4e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Check input against model dimensionality and chroma in all ilastik cases

parent 2dfd9037
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!29Dev multichannel classifiers
......@@ -36,7 +36,6 @@ class IlastikModel(Model):
raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
self.shell = None
self.axes = None
super().__init__(autoload, params)
def load(self):
......@@ -65,18 +64,10 @@ class IlastikModel(Model):
raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
)
self.axes = [
a['key'].upper() for a in json.loads(h5[f'Input Data/infos/lane0000/Raw Data/axistags'][()])['axes']
]
self.shell = shell
return True
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@property
def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
......@@ -97,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def model_3d(self):
return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@staticmethod
def get_workflow():
from ilastik.workflows import PixelClassificationWorkflow
......@@ -145,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
assert segmentation_img.is_mask()
if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack)
......@@ -202,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment