Skip to content
Snippets Groups Projects
Commit 2a542f58 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

ilastik now allows multichannel inputs

parent b9037651
No related branches found
No related tags found
No related merge requests found
......@@ -104,8 +104,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
return [l.decode() for l in h5['PixelClassification/LabelNames'][()]]
def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
else:
raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
dsi = [
......@@ -164,7 +171,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}'
f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
......@@ -229,8 +236,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
else:
raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
if isinstance(input_img, PatchStack):
assert isinstance(pxmap_img, PatchStack)
......
......@@ -26,6 +26,7 @@ def infer_px_then_ob_model(
px_model: IlastikPixelClassifierModel,
ob_model: IlastikObjectClassifierFromPixelPredictionsModel,
where_output: Path,
channel: int = None,
**kwargs
) -> WorkflowRunRecord:
"""
......@@ -35,6 +36,7 @@ def infer_px_then_ob_model(
:param px_model: model instance for pixel classification
:param ob_model: model instance for object classification
:param where_output: Path object that references output image directory
:param channel: input image channel to pass to pixel classification, or all channels if None
:param kwargs: variable-length keyword arguments
:return:
"""
......@@ -42,8 +44,12 @@ def infer_px_then_ob_model(
assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel)
ti = Timer()
ch = kwargs.get('channel')
img = generate_file_accessor(fpi).get_one_channel_data(ch, mip=kwargs.get('mip', False))
raw_acc = generate_file_accessor(fpi)
if channel is not None:
channels = [channel]
else:
channels = range(0, raw_acc.chroma)
img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False))
ti.click('file_input')
px_map, _ = px_model.infer(img)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment