models.py 5.05 KiB
from pathlib import Path
import shutil
import h5py
import numpy as np
import vigra
from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """
    Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented
    as time-series images where each frame contains only one object.
    """

    def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> 'tuple[MonoPatchStack, dict]':
        """
        Run ilastik object classification on a stack of single-object patches.

        :param input_acc: single-channel stack of raw-data patches
        :param segmentation_acc: stack of binary object masks corresponding to input_acc
        :return: tuple of (object map wrapped in a MonoPatchStack, with data ordered (Y, X, Z)), {'success': True})
        :raises AssertionError: if segmentation_acc is not a mask, input_acc is not single-channel, or ilastik's
            output does not have the expected count or shape
        """
        assert segmentation_acc.is_mask()
        assert input_acc.chroma == 1

        tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
        tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')

        # a single lane: mapping from ilastik dataset role to the preloaded array for that role
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
        assert len(obmaps) == 1, f'Expected exactly one object map from ilastik, got {len(obmaps)}'

        # for some reason ilastik scrambles these axes to Z(1)YX(1), i.e. shape (nz, 1, h, w, 1)
        assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
        zyx = obmaps[0][:, 0, :, :, 0]

        # move the leading Z axis to the end: (nz, h, w) -> (h, w, nz)
        yxz = np.moveaxis(zyx, 0, -1)

        assert yxz.shape[0:2] == input_acc.hw
        assert yxz.shape[2] == input_acc.nz
        return MonoPatchStack(data=yxz), {'success': True}
def _object_labels_from_stack(label_stack: MonoPatchStackFromFile, label_names: list) -> list:
    """
    Return the single non-zero label ID found in each patch of label_stack, in patch order.

    :param label_stack: stack of label patches; each must contain background (0) and exactly one non-zero label ID
    :param label_names: list of label names; label IDs may not exceed its length
    :return: list with one label ID per patch
    :raises AssertionError: if any patch violates the constraints above
    """
    labels = []
    for ii in range(label_stack.count):
        unique = np.unique(label_stack.iat(ii))  # sorted ascending
        assert len(unique) >= 2, 'Label image does not contain both unlabeled background and a non-zero label'
        assert unique[0] == 0, 'Label image does not contain unlabeled background'
        assert len(unique) == 2, 'Label image contains more than one non-zero value'
        assert unique[-1] < len(label_names) + 1, f'Label ID {unique[-1]} exceeds number of label names: {len(label_names)}'
        labels.append(unique[-1])
    return labels


def _delete_if_present(h5, key: str) -> None:
    """Delete dataset or group `key` from open HDF5 file/group `h5` if it exists; do nothing otherwise."""
    if key in h5:
        del h5[key]


def generate_ilastik_object_classifier(
        template_ilp: Path,
        target_ilp: Path,
        raw_stack: MonoPatchStackFromFile,
        mask_stack: MonoPatchStackFromFile,
        label_stack: MonoPatchStackFromFile,
        label_names: list,
        lane: int = 0,
) -> Path:
    """
    Starting with a template project file, transfer input data and labels to a new project file.

    The template is copied to target_ilp and the copy is edited in place:
    - the lane's 'Raw Data' and 'Segmentation Image' entries are pointed at the stacks' files;
    - label names and per-patch object labels are replaced;
    - any stored region features and classifier forests are removed.

    :param template_ilp: path to existing ilastik object classifier to use as a template
    :param target_ilp: path to new classifier
    :param raw_stack: stack of patches containing raw data
    :param mask_stack: stack of patches containing object masks
    :param label_stack: stack of patches containing object labels; each patch must contain unlabeled background (0)
        and exactly one non-zero label ID no greater than len(label_names)
    :param label_names: list of label names
    :param lane: ilastik lane identifier
    :return: path to generated object classifier
    :raises AssertionError: if stack shapes differ, a label patch is malformed, or the template's input data entries
        are not of the expected form
    """
    assert mask_stack.shape == raw_stack.shape
    assert label_stack.shape == raw_stack.shape

    # validate and collect labels before copying, so a bad label stack does not leave a partial target file behind
    labels = _object_labels_from_stack(label_stack, label_names)

    new_ilp = shutil.copy(template_ilp, target_ilp)

    accessors = {
        'Raw Data': raw_stack,
        'Segmentation Image': mask_stack,
    }

    # write to new project file
    with h5py.File(new_ilp, 'r+') as h5:
        for gk, acc in accessors.items():
            group = f'Input Data/infos/lane{lane:04d}/{gk}'

            # set path to input image file; only the file name (a relative path) is stored
            # NOTE(review): presumably ilastik resolves it relative to the project file's directory — confirm
            del h5[f'{group}/filePath']
            h5[f'{group}/filePath'] = acc.fpath.name
            assert not Path(h5[f'{group}/filePath'][()].decode()).is_absolute()
            assert h5[f'{group}/filePath'][()] == acc.fpath.name.encode()
            assert h5[f'{group}/location'][()] == 'FileSystem'.encode()

            # set input nickname
            del h5[f'{group}/nickname']
            h5[f'{group}/nickname'] = acc.fpath.stem

            # set input shape, ordered Z, Y, X
            del h5[f'{group}/shape']
            h5[f'{group}/shape'] = np.array([acc.shape_dict[ax] for ax in ['Z', 'Y', 'X']])

        # replace label names
        k = 'ObjectClassification/LabelNames'
        _delete_if_present(h5, k)
        h5.create_dataset(k, data=np.array(label_names).astype('O'))

        k = 'ObjectClassification/MaxNumObj'
        _delete_if_present(h5, k)
        h5[k] = len(label_names) - 1

        # NOTE(review): presumably selects which applet ilastik opens on — confirm
        del h5['currentApplet']
        h5['currentApplet'] = 1

        # replace object labels: one entry per patch, keyed by patch index
        k = f'ObjectClassification/LabelInputs/{lane:04d}'
        _delete_if_present(h5, k)
        label_group = h5.create_group(k)
        for zi, la in enumerate(labels):
            # NOTE(review): presumably [value for object 0 (background), label of the patch's object] — confirm
            label_group[f'{zi}'] = np.array([0., float(la)])

        # delete existing features and classification weights so they are recomputed for the new inputs
        _delete_if_present(h5, f'ObjectExtraction/RegionFeatures/{lane:04d}')
        _delete_if_present(h5, 'ObjectClassification/ClassifierForests')

    return Path(new_ilp)