-
Christopher Randolph Rhodes authored
models.py 5.05 KiB
from pathlib import Path
import shutil
import h5py
import numpy as np
import vigra
from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """
    Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented
    as time-series images where each frame contains only one object.
    """

    def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> (MonoPatchStack, dict):
        """
        Run the loaded ilastik object-classification project on a stack of patches.

        :param input_acc: single-channel stack of raw-data patches
        :param segmentation_acc: stack of binary object masks matching input_acc
        :return: tuple of (MonoPatchStack holding the single map exported by ilastik's batch-processing applet,
            arranged YXZ with one Z-slice per input patch; dict of metadata, currently {'success': True})
        :raises AssertionError: if the inputs are not a single-channel stack plus a mask, or ilastik's output does
            not have the expected count or shape
        """
        assert segmentation_acc.is_mask(), 'segmentation_acc must contain a binary mask'
        assert input_acc.chroma == 1, f'input_acc must be single-channel; got chroma={input_acc.chroma}'

        tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
        tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')

        # one dataset-info dict per lane; only a single lane is processed here
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
        assert len(obmaps) == 1, f'ilastik generated {len(obmaps)} object maps; expected exactly one'

        # for some reason ilastik scrambles the axes to Z(1)YX(1): shape (nz, 1, h, w, 1), with the leading axis
        # running over the patches of the input stack
        expected_shape = (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
        assert obmaps[0].shape == expected_shape, \
            f'unexpected object map shape {obmaps[0].shape}; expected {expected_shape}'

        # drop the two singleton axes, then reorder (nz, h, w) -> (h, w, nz) to build a YXZ patch stack
        yxz = np.transpose(obmaps[0][:, 0, :, :, 0], (1, 2, 0))
        assert yxz.shape[0:2] == input_acc.hw, f'object map height/width {yxz.shape[0:2]} != {input_acc.hw}'
        assert yxz.shape[2] == input_acc.nz, f'object map depth {yxz.shape[2]} != {input_acc.nz}'
        return MonoPatchStack(data=yxz), {'success': True}
def _object_labels_from_stack(label_stack: MonoPatchStackFromFile, label_names: list) -> list:
    """
    Extract and validate the single object label ID carried by each patch of a label stack.

    :param label_stack: stack of patches; each must contain background (0) and exactly one non-zero label ID
    :param label_names: list of label names; label IDs must lie in 1..len(label_names)
    :return: list with one label ID per patch, in stack order
    :raises AssertionError: if any patch violates the constraints above
    """
    labels = []
    for ii in range(label_stack.count):
        unique = np.unique(label_stack.iat(ii))
        # exactly two distinct values are required: background plus one object label
        assert len(unique) == 2, \
            f'Label image {ii} must contain background and exactly one non-zero value; found values {unique}'
        assert unique[0] == 0, f'Label image {ii} does not contain unlabeled background'
        assert unique[-1] <= len(label_names), \
            f'Label ID {unique[-1]} exceeds number of label names: {len(label_names)}'
        labels.append(unique[-1])
    return labels


def generate_ilastik_object_classifier(
        template_ilp: Path,
        target_ilp: Path,
        raw_stack: MonoPatchStackFromFile,
        mask_stack: MonoPatchStackFromFile,
        label_stack: MonoPatchStackFromFile,
        label_names: list,
        lane: int = 0,
) -> Path:
    """
    Starting with a template project file, transfer input data and labels to a new project file.

    :param template_ilp: path to existing ilastik object classifier to use as a template
    :param target_ilp: path to new classifier
    :param raw_stack: stack of patches containing raw data
    :param mask_stack: stack of patches containing object masks
    :param label_stack: stack of patches containing object labels; each patch must hold background (0) and exactly
        one label ID in 1..len(label_names)
    :param label_names: list of label names
    :param lane: ilastik lane identifier
    :return: path to generated object classifier
    :raises AssertionError: if stack shapes differ, a label patch is invalid, or the template lacks expected entries
    """
    assert mask_stack.shape == raw_stack.shape
    assert label_stack.shape == raw_stack.shape

    # validate labels before touching the filesystem, so bad input does not leave a half-written project behind
    labels = _object_labels_from_stack(label_stack, label_names)

    new_ilp = shutil.copy(template_ilp, target_ilp)

    accessors = {
        'Raw Data': raw_stack,
        'Segmentation Image': mask_stack,
    }

    with h5py.File(new_ilp, 'r+') as h5:
        for gk, acc in accessors.items():
            group = f'Input Data/infos/lane{lane:04d}/{gk}'

            # store only the file name (a relative path); presumably ilastik resolves it relative to the project
            # file, so the data files must sit alongside target_ilp — confirm
            del h5[f'{group}/filePath']
            h5[f'{group}/filePath'] = acc.fpath.name
            assert not Path(h5[f'{group}/filePath'][()].decode()).is_absolute()
            assert h5[f'{group}/filePath'][()] == acc.fpath.name.encode()
            assert h5[f'{group}/location'][()] == 'FileSystem'.encode()

            del h5[f'{group}/nickname']
            h5[f'{group}/nickname'] = acc.fpath.stem

            del h5[f'{group}/shape']
            h5[f'{group}/shape'] = np.array([acc.shape_dict[ax] for ax in ['Z', 'Y', 'X']])

        # replace label names with the caller's list (object array so h5py writes variable-length strings)
        k = 'ObjectClassification/LabelNames'
        if k in h5:
            del h5[k]
        h5.create_dataset(k, data=np.array(label_names).astype('O'))

        # NOTE(review): ilastik semantics of MaxNumObj are not established here; value kept as len(label_names) - 1
        k = 'ObjectClassification/MaxNumObj'
        if k in h5:
            del h5[k]
        h5[k] = len(label_names) - 1

        # presumably selects which applet ilastik shows when the project is opened — confirm
        del h5['currentApplet']
        h5['currentApplet'] = 1

        # replace object labels: one two-element entry per frame, keyed by frame index
        k = f'ObjectClassification/LabelInputs/{lane:04d}'
        if k in h5:
            del h5[k]
        lag = h5.create_group(k)
        for zi, la in enumerate(labels):
            # presumably entry i is the label of object i in the frame, object 0 being background — confirm
            lag[f'{zi}'] = np.array([0., float(la)])

        # drop cached features and trained forests so they are recomputed for the new inputs
        for k in [f'ObjectExtraction/RegionFeatures/{lane:04d}', 'ObjectClassification/ClassifierForests']:
            if k in h5:
                del h5[k]

    return Path(new_ilp)