Skip to content
Snippets Groups Projects
Commit 98797757 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'dev_multichannel_classifiers' into 'staging'

Dev multichannel classifiers

See merge request !29
parents 74cbb1dd f5e4fc4e
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!29Dev multichannel classifiers
...@@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel): ...@@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel):
""" """
if not mask.is_mask(): if not mask.is_mask():
raise InvalidInputImageError('Expecting a binary mask') raise InvalidInputImageError('Expecting a binary mask')
if not img.shape == mask.shape: if img.hw != mask.hw or img.nz != mask.nz:
raise InvalidInputImageError('Expect input image and mask to be the same shape') raise InvalidInputImageError('Expect input image and mask to be the same shape')
def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs): def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
......
...@@ -53,14 +53,13 @@ class IlastikModel(Model): ...@@ -53,14 +53,13 @@ class IlastikModel(Model):
shell = app.main(args, init_logging=False) shell = app.main(args, init_logging=False)
# validate if inputs are embedded in project file # validate if inputs are embedded in project file
input_groups = shell.projectManager.currentProjectFile['Input Data']['infos'] h5 = shell.projectManager.currentProjectFile
lanes = input_groups.keys() for lane in h5['Input Data/infos'].keys():
for ll in lanes: for role in h5[f'Input Data/infos/{lane}'].keys():
input_types = input_groups[ll] grp = h5[f'Input Data/infos/{lane}/{role}']
for tt in input_types: if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal':
ds_loc = input_groups[ll][tt].get('location', False)
if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem':
raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem') raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')
assert True
if not isinstance(shell.workflow, self.get_workflow()): if not isinstance(shell.workflow, self.get_workflow()):
raise ParameterExpectedError( raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}' f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
...@@ -69,11 +68,6 @@ class IlastikModel(Model): ...@@ -69,11 +68,6 @@ class IlastikModel(Model):
return True return True
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@property @property
def model_shape_dict(self): def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
...@@ -94,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ...@@ -94,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def model_3d(self): def model_3d(self):
return self.model_shape_dict['Z'] > 1 return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@staticmethod @staticmethod
def get_workflow(): def get_workflow():
from ilastik.workflows import PixelClassificationWorkflow from ilastik.workflows import PixelClassificationWorkflow
...@@ -142,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment ...@@ -142,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
return ObjectClassificationWorkflowBinary return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
assert segmentation_img.is_mask() assert segmentation_img.is_mask()
if isinstance(input_img, PatchStack): if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack) assert isinstance(segmentation_img, PatchStack)
...@@ -199,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag ...@@ -199,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz') tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment