-
Christopher Randolph Rhodes authored
transfer_labels_to_ilastik_object_classifier.py 7.58 KiB
import shutil
from pathlib import Path
import h5py
import json
import numpy as np
import pandas as pd
import skimage
import uuid
import vigra
from extensions.chaeo.util import autonumber_new_file
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """
    Object classifier that feeds a patch stack to ilastik with the stack axis tagged as time.

    The accessor's Z axis is presented to ilastik as 't', and the returned object map is rearranged back so
    that its shape matches the input accessor.  `self.shell` and `self.PreloadedArrayDatasetInfo` are not
    defined here; presumably the parent class provides them (verify against the ilastik extension).
    """

    @staticmethod
    def make_tczyx(acc):
        """
        Rearrange the data of a single-channel accessor into a 5-D array tagged 'tczyx'.

        :param acc: accessor whose `data` is indexed as YXCZ and whose `chroma` is 1
        :return: array of shape (nz, 1, 1, h, w); the stack axis becomes the leading 't' axis
        """
        assert acc.chroma == 1
        tyx = np.moveaxis(
            acc.data[:, :, 0, :],  # drop the singleton channel axis: YX(C)Z -> YXZ
            [2, 0, 1],
            [0, 1, 2]
        )
        # insert singleton 'c' and 'z' axes after the leading 't' axis
        return np.expand_dims(tyx, (1, 2))

    def infer(
            self,
            input_img: GenericImageDataAccessor,
            segmentation_img: GenericImageDataAccessor
    ) -> (InMemoryDataAccessor, dict):
        """
        Run the ilastik batch export on a patch stack and return the resulting object map.

        :param input_img: single-channel raw patch stack
        :param segmentation_img: mask patch stack; must satisfy `is_mask()`
        :return: tuple of (object map accessor with the same shape as `input_img`, {'success': True})
        :raises AssertionError: if an input is not of the expected kind, or if ilastik's output does not have
            the expected count or shape
        """
        assert segmentation_img.is_mask()
        assert input_img.chroma == 1

        tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
        tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        # for some reason these axes get scrambled to Z(1)YX(1)
        assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1)

        # drop the trailing singleton axis, then reorder Z(1)YX -> YX(1)Z
        yxcz = np.moveaxis(
            obmaps[0][:, :, :, :, 0],
            [2, 3, 1, 0],
            [0, 1, 2, 3]
        )

        assert yxcz.shape == input_img.shape
        return InMemoryDataAccessor(data=yxcz), {'success': True}
def get_dataset_info(h5, lane=0):
    """
    Gather input-dataset metadata for one lane of an open ilastik project file.

    :param h5: open project file (or any mapping) indexable by HDF5-style path strings
    :param lane: integer lane index; formatted as a zero-padded four-digit string in the paths
    :return: dict with one entry each for 'Raw Data' and 'Segmentation Image', holding whichever of
        'location', 'filePath', 'shape', 'nickname' could be read, plus 'id', 'axistags' and 'axes';
        and a 'misc' entry holding the number of items in the lane's label-input group
    """
    lane_id = f'{lane:04d}'
    lane_root = f'Input Data/infos/lane{lane_id}'
    optional_fields = ('location', 'filePath', 'shape', 'nickname')

    info = {}
    for role in ('Raw Data', 'Segmentation Image'):
        role_root = f'{lane_root}/{role}'
        entry = {}
        info[role] = entry

        # best-effort: a field that cannot be read is printed and left out of the result
        for field in optional_fields:
            try:
                entry[field] = h5[f'{role_root}/{field}'][()]
            except Exception as err:
                print(err)

        # a dataset ID that does not parse as a UUID is replaced by a placeholder string
        try:
            entry['id'] = uuid.UUID(h5[f'{role_root}/datasetId'][()].decode())
        except ValueError:
            entry['id'] = '<invalid UUID>'

        # axis tags are stored as JSON; 'axes' is the list of axis keys in stored order
        tags = json.loads(h5[f'{role_root}/axistags'][()].decode())
        entry['axistags'] = tags
        entry['axes'] = [axis['key'] for axis in tags['axes']]

    label_inputs = h5[f'ObjectClassification/LabelInputs/{lane_id}']
    info['misc'] = {'number_of_label_inputs': len(label_inputs.items())}
    return info
def generate_ilastik_object_classifier(template_ilp, where: str, lane=0):
    """
    Copy a template ilastik object-classification project and point it at a labeled patch stack.

    The directory `where` must contain 'zstack_train_raw.tif', 'zstack_train_mask.tif', 'train_stack.csv'
    (columns 'zi' and 'annotation_class_id') and 'labels_key.csv' (columns 'annotation_class_id' and
    'annotation_class').  The copy is written into the same directory under an auto-numbered 'auto-obj'
    name, and its input paths, label names and per-slice object labels are overwritten.

    :param template_ilp: path to the template project file; both of its inputs must be located on the
        file system and carry the axes ['t', 'y', 'x']
    :param where: directory holding the patch stack and CSV files
    :param lane: index of the project lane to validate and overwrite
    :return: path of the new project file, as returned by shutil.copy
    :raises AssertionError: if the stacks, CSV files or template do not meet the expectations above
    """

    # validate z-stack input data
    root = Path(where)
    paths = {
        'Raw Data': root / 'zstack_train_raw.tif',
        'Segmentation Image': root / 'zstack_train_mask.tif',
    }
    accessors = {k: generate_file_accessor(pa) for k, pa in paths.items()}

    assert accessors['Raw Data'].chroma == 1
    assert accessors['Segmentation Image'].is_mask()
    assert len({a.hw for a in accessors.values()}) == 1  # same height and width
    assert len({a.nz for a in accessors.values()}) == 1  # same z-depth
    nz = accessors['Raw Data'].nz

    # now load CSV
    csv_path = root / 'train_stack.csv'
    assert csv_path.exists()
    df_patches = pd.read_csv(csv_path)

    # every slice index 0..nz-1 must appear exactly once; array_equal also rejects a length mismatch
    # instead of attempting an elementwise comparison of differently sized arrays
    assert np.array_equal(
        df_patches['zi'].sort_values().to_numpy(),
        np.arange(0, nz)
    )

    df_labels = pd.read_csv(root / 'labels_key.csv')
    label_names = list(df_labels.sort_values('annotation_class_id').annotation_class.unique())
    assert len(label_names) >= 2  # checked before indexing so an empty key fails here, not with IndexError
    label_names[0] = 'none'  # the class with the lowest ID is renamed

    # open, validate, and copy template project file
    lns = f'{lane:04d}'
    with h5py.File(template_ilp, 'r') as h5:
        # inspect the same lane that is overwritten below
        info = get_dataset_info(h5, lane)
        for hg in ['Raw Data', 'Segmentation Image']:
            assert info[hg]['location'] == b'FileSystem'
            assert info[hg]['axes'] == ['t', 'y', 'x']

    new_ilp = shutil.copy(template_ilp, root / autonumber_new_file(root, 'auto-obj', 'ilp'))

    # write to new project file
    with h5py.File(new_ilp, 'r+') as h5:

        def set_ds(grp, ds, val):
            # overwrite one scalar/array dataset of this lane's input info and return what was stored
            dataset = h5[f'Input Data/infos/lane{lns}/{grp}/{ds}']
            dataset[()] = val
            return dataset[()]

        def get_label(idx):
            # annotation class of the patch stored at slice index idx
            return df_patches.loc[df_patches.zi == idx, 'annotation_class_id'].iat[0]

        for hg in ['Raw Data', 'Segmentation Image']:
            set_ds(hg, 'filePath', str(paths[hg]))
            set_ds(hg, 'nickname', paths[hg].stem)
            shape_zyx = [accessors[hg].shape_dict[ax] for ax in ['Z', 'Y', 'X']]
            set_ds(hg, 'shape', np.array(shape_zyx))

        # change key of label names
        del h5['ObjectClassification/LabelNames']
        ln = np.array(label_names)
        h5.create_dataset('ObjectClassification/LabelNames', data=ln.astype('O'))

        # change object labels: one dataset per slice, holding [0., <class id of that slice>]
        la_groupname = f'ObjectClassification/LabelInputs/{lns}'
        del h5[la_groupname]
        lag = h5.create_group(la_groupname)
        for zi in range(0, nz):
            lag[f'{zi}'] = np.array([0., float(get_label(zi))])

    return new_ilp
def compare_object_maps(truth: 'GenericImageDataAccessor', inferred: 'GenericImageDataAccessor') -> pd.DataFrame:
    """
    Tabulate, slice by slice, the class label in a ground-truth object map against an inferred one.

    :param truth: ground-truth object map; zero wherever `inferred` is zero
    :param inferred: inferred object map with a single channel; every slice must contain at least one zero
    :return: DataFrame with one row per slice and columns 'zi', 'truth_label', 'multiples',
        'inferred_label'.  A slice without any object gets label 0 in both label columns.  'multiples' is
        True where the inferred slice holds more than one class, in which case 'inferred_label' is the
        class of its largest connected region.
    :raises AssertionError: if the two maps differ in shape or background, or `inferred` is not single-channel
    """
    assert truth.shape == inferred.shape
    assert np.all((truth.data == 0) == (inferred.data == 0))
    assert inferred.chroma == 1

    rows = []
    for zi in range(inferred.nz):
        inf_img = inferred.data[:, :, :, zi]
        inf_values = np.unique(inf_img)
        assert inf_values[0] == 0

        # The backgrounds coincide, so the truth slice also contains zero and its sorted unique values
        # start with 0; fall back to that when the slice holds no object instead of indexing past the end.
        truth_values = np.unique(truth.data[:, :, :, zi])
        row = {
            'zi': zi,
            'truth_label': truth_values[1] if len(truth_values) > 1 else truth_values[0],
            'multiples': False,
        }

        if len(inf_values) == 1:  # no object in frame
            row['inferred_label'] = inf_values[0]
        elif len(inf_values) > 2:  # multiple classes in frame, so keep only the largest region
            # imported here so the submodule is loaded explicitly and only when this branch is needed
            from skimage import measure

            ob_id = measure.label(inf_img)
            pr = measure.regionprops_table(ob_id, properties=['label', 'area'])
            largest = pr['label'][pr['area'].argmax()]
            # region IDs live in ob_id, not in inf_img; a labeled region has one uniform class value
            row['inferred_label'] = inf_img[ob_id == largest][0]
            row['multiples'] = True
        else:  # exactly one unique object class in frame
            row['inferred_label'] = inf_values[1]

        rows.append(row)

    return pd.DataFrame(rows)
if __name__ == '__main__':
    # Example run: build a classifier project from a labeled patch stack, run it on its own training
    # stack, and save both the inferred object map and a per-slice comparison with the training labels.
    project_dir = Path('c:/Users/rhodes/projects/proj0011-plankton-seg/')
    stack_dir = project_dir / 'exp0009/output/labeled_patches-20231016-0002'

    # alternative template: project_dir / 'exp0014/test_obj_from_seg.ilp'
    classifier_ilp = generate_ilastik_object_classifier(
        project_dir / 'exp0014/template_obj.ilp',
        stack_dir,
    )

    raw_stack = generate_file_accessor(stack_dir / 'zstack_train_raw.tif')
    mask_stack = generate_file_accessor(stack_dir / 'zstack_train_mask.tif')

    classifier = PatchStackObjectClassifier({'project_file': classifier_ilp})
    inferred_acc, _ = classifier.infer(raw_stack, mask_stack)

    result_path = stack_dir / 'result.tif'
    write_accessor_data_to_file(result_path, inferred_acc)
    print(result_path)

    # write comparison
    truth_acc = generate_file_accessor(stack_dir / 'zstack_train_label.tif')
    comparison = compare_object_maps(truth_acc, inferred_acc)
    comparison_path = stack_dir / autonumber_new_file(stack_dir, 'comp', 'csv')
    comparison.to_csv(comparison_path, index=False)