Skip to content
Snippets Groups Projects
Commit 98797757 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'dev_multichannel_classifiers' into 'staging'

Dev multichannel classifiers

See merge request !29
parents 74cbb1dd f5e4fc4e
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!29Dev multichannel classifiers
......@@ -85,7 +85,7 @@ class InstanceSegmentationModel(ImageToImageModel):
"""
if not mask.is_mask():
raise InvalidInputImageError('Expecting a binary mask')
if not img.shape == mask.shape:
if img.hw != mask.hw or img.nz != mask.nz:
raise InvalidInputImageError('Expect input image and mask to be the same shape')
def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
......
......@@ -53,14 +53,13 @@ class IlastikModel(Model):
shell = app.main(args, init_logging=False)
# validate if inputs are embedded in project file
input_groups = shell.projectManager.currentProjectFile['Input Data']['infos']
lanes = input_groups.keys()
for ll in lanes:
input_types = input_groups[ll]
for tt in input_types:
ds_loc = input_groups[ll][tt].get('location', False)
if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem':
h5 = shell.projectManager.currentProjectFile
for lane in h5['Input Data/infos'].keys():
for role in h5[f'Input Data/infos/{lane}'].keys():
grp = h5[f'Input Data/infos/{lane}/{role}']
if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal':
raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')
assert True
if not isinstance(shell.workflow, self.get_workflow()):
raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
......@@ -69,11 +68,6 @@ class IlastikModel(Model):
return True
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@property
def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
......@@ -94,6 +88,11 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
def model_3d(self):
return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification'
operations = ['segment', ]
@staticmethod
def get_workflow():
from ilastik.workflows import PixelClassificationWorkflow
......@@ -142,6 +141,9 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
assert segmentation_img.is_mask()
if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack)
......@@ -199,6 +201,9 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment