# transfer_labels_to_ilastik_object_classifier.py
import shutil
from pathlib import Path
import h5py
import json
import numpy as np
import pandas as pd
import skimage
import uuid
import vigra

from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file

class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """
    Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented
    as time-series images where each frame contains only one object.

    Accessors are handled in YXCZ axis order with a single channel; their Z axis is handed to ilastik as the time (T)
    axis, so each z-slice is classified as a separate frame.
    """

    @staticmethod
    def make_tczyx(acc: GenericImageDataAccessor):
        """
        Rearrange a single-channel YXCZ accessor into a TCZYX array, mapping accessor Z onto T.

        :param acc: image accessor with chroma == 1
        :return: (np.ndarray) array of shape (nz, 1, 1, height, width)
        """
        assert acc.chroma == 1
        tyx = np.moveaxis(
            acc.data[:, :, 0, :],  # YX(C)Z -> YXZ, dropping the single channel
            [2, 0, 1],
            [0, 1, 2]
        )  # YXZ -> ZYX, treated as TYX from here on
        return np.expand_dims(tyx, (1, 2))  # insert singleton C and Z axes

    def infer(
            self,
            input_img: GenericImageDataAccessor,
            segmentation_img: GenericImageDataAccessor,
    ) -> 'tuple[InMemoryDataAccessor, dict]':
        """
        Run the loaded ilastik object classifier over a patch stack.

        :param input_img: single-channel raw image stack (YXCZ), one object per z-slice
        :param segmentation_img: binary mask stack with the same shape as input_img
        :return: (InMemoryDataAccessor, dict) exported object map with the same shape as input_img, and a status dict
        """
        assert segmentation_img.is_mask()
        assert input_img.chroma == 1
        # fail early with a clear message rather than inside ilastik or at the final shape check
        assert segmentation_img.shape == input_img.shape, 'raw and segmentation stacks must have the same shape'

        tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
        tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        # list of exported arrays; exactly one is expected for the single lane in dsi
        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        # for some reason ilastik scrambles these axes to Z(1)YX(1)
        assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1)
        yxcz = np.moveaxis(
            obmaps[0][:, :, :, :, 0],  # drop trailing singleton -> Z(1)YX
            [2, 3, 1, 0],
            [0, 1, 2, 3]
        )  # Z(1)YX -> YX(1)Z, i.e. back to accessor YXCZ order

        assert yxcz.shape == input_img.shape
        return InMemoryDataAccessor(data=yxcz), {'success': True}

def get_dataset_info(h5: 'h5py.File', lane: int = 0):
    """
    Report out specific datasets in ilastik project file HDF5

    Missing per-input datasets (location, filePath, shape, nickname) are reported on stdout and omitted from the
    result rather than raising.

    :param h5: handle to ilastik project file, as h5py.File object
    :param lane: ilastik lane identifier
    :return: (dict) selected data values from project file, keyed by 'Raw Data', 'Segmentation Image', and 'misc'
    """
    lns = f'{lane:04d}'
    lane_prefix = f'Input Data/infos/lane{lns}'
    info = {}
    for gk in ['Raw Data', 'Segmentation Image']:
        grp = f'{lane_prefix}/{gk}'
        info[gk] = {}
        for dk in ['location', 'filePath', 'shape', 'nickname']:
            try:
                info[gk][dk] = h5[f'{grp}/{dk}'][()]
            except KeyError as e:  # dataset absent from this project file; tolerated
                print(f'{grp}/{dk} not found: {e}')
        try:
            info[gk]['id'] = uuid.UUID(h5[f'{grp}/datasetId'][()].decode())
        except ValueError:
            info[gk]['id'] = '<invalid UUID>'
        info[gk]['axistags'] = json.loads(h5[f'{grp}/axistags'][()].decode())
        info[gk]['axes'] = [ax['key'] for ax in info[gk]['axistags']['axes']]

    obj_cl_group = h5[f'ObjectClassification/LabelInputs/{lns}']
    info['misc'] = {
        'number_of_label_inputs': len(obj_cl_group)
    }
    return info


def generate_ilastik_object_classifier(
        template_ilp: str,
        where: str,
        stack_name: str = 'train',
        lane: int = 0,
        proj_name='auto_obj'
):
    """
    Starting with a template project file, transfer input data and labels to a duplicate project file.

    Reads from `where`: zstack_<stack_name>_raw.tif, zstack_<stack_name>_mask.tif, <stack_name>_stack.csv and
    labels_key.csv.

    :param template_ilp: absolute path to existing ilastik object classifier to use as a template
    :param where: absolute path to folder containing input data, segmentation maps, labels, and label descriptions;
        may be a str or a Path
    :param stack_name: prefix of .tif and .csv files that contain classifier training data (e.g. train, test)
    :param lane: ilastik lane identifier
    :param proj_name: filename stem of the new project file, which is written into `where`
    :return: (Path) path to the new ilastik classifier project file, as returned by shutil.copy
    """

    # validate z-stack input data
    root = Path(where)
    rel_paths = {
        'Raw Data': Path(f'zstack_{stack_name}_raw.tif'),
        'Segmentation Image': Path(f'zstack_{stack_name}_mask.tif'),
    }

    accessors = {k: generate_file_accessor(root / pa) for k, pa in rel_paths.items()}

    assert accessors['Raw Data'].chroma == 1
    assert accessors['Segmentation Image'].is_mask()
    assert len({a.hw for a in accessors.values()}) == 1  # same height and width
    assert len({a.nz for a in accessors.values()}) == 1  # same z-depth
    nz = accessors['Raw Data'].nz

    # now load CSV
    csv_path = root / f'{stack_name}_stack.csv'
    assert csv_path.exists()
    df_patches = pd.read_csv(csv_path)
    # every z-slice index must appear exactly once; array_equal also handles a length mismatch, where
    # elementwise == can raise instead of returning False
    assert np.array_equal(df_patches['zi'].sort_values().to_numpy(), np.arange(0, nz))
    df_labels = pd.read_csv(root / 'labels_key.csv')
    label_names = list(df_labels.sort_values('annotation_class_id').annotation_class.unique())
    assert len(label_names) >= 2
    # NOTE(review): the name of the lowest annotation_class_id is replaced with 'none'; presumably that class
    # denotes background/unlabeled objects -- confirm against labels_key.csv
    label_names[0] = 'none'

    # open, validate, and copy template project file
    with h5py.File(template_ilp, 'r') as h5:
        info = get_dataset_info(h5, lane=lane)

        for hg in ['Raw Data', 'Segmentation Image']:
            assert info[hg]['location'] == b'FileSystem'
            assert info[hg]['axes'] == ['t', 'y', 'x']

    # build the destination from root (a Path) so `where` may also be passed as a str
    new_ilp = shutil.copy(template_ilp, root / (proj_name + '.ilp'))

    # write to new project file
    lns = f'{lane:04d}'
    with h5py.File(new_ilp, 'r+') as h5:  # new_ilp is already the full destination path
        def set_ds(grp, name, val):
            dset = h5[f'Input Data/infos/lane{lns}/{grp}/{name}']
            dset[()] = val
            return dset[()]

        def get_label(idx):
            return df_patches.loc[df_patches.zi == idx, 'annotation_class_id'].iat[0]

        for hg in ['Raw Data', 'Segmentation Image']:
            set_ds(hg, 'filePath', str(rel_paths[hg]))
            set_ds(hg, 'nickname', rel_paths[hg].stem)
            shape_zyx = [accessors[hg].shape_dict[ax] for ax in ['Z', 'Y', 'X']]
            set_ds(hg, 'shape', np.array(shape_zyx))

        # replace label names
        if (k_ln := 'ObjectClassification/LabelNames') in h5:
            del h5[k_ln]
        h5.create_dataset(k_ln, data=np.array(label_names).astype('O'))

        if (k_mn := 'ObjectClassification/MaxNumObj') in h5:
            del h5[k_mn]
        h5[k_mn] = len(label_names) - 1

        if 'currentApplet' in h5:
            del h5['currentApplet']
        h5['currentApplet'] = 1

        # replace object labels: one dataset per z-slice, each holding [0., annotation_class_id]
        if (k_li := f'ObjectClassification/LabelInputs/{lns}') in h5:
            del h5[k_li]
        lag = h5.create_group(k_li)
        for zi in range(0, nz):
            lag[f'{zi}'] = np.array([0., float(get_label(zi))])

        # delete existing region features and classification weights carried over from the template
        if (k_rf := f'ObjectExtraction/RegionFeatures/{lns}') in h5:
            del h5[k_rf]
        if (k_cf := 'ObjectClassification/ClassificationForests') in h5:
            del h5[k_cf]

    return new_ilp

def compare_object_maps(truth: 'GenericImageDataAccessor', inferred: 'GenericImageDataAccessor') -> pd.DataFrame:
    """
    Compare two object maps to assess classification results

    Each z-slice is treated as one frame. The truth label of a frame is its single nonzero class (0 if the frame is
    empty). The inferred label is the frame's single class or, when several classes are present, the class of the
    largest connected object.

    :param truth: t-stack of truth objects
    :param inferred: t-stack of inferred objects, presumably with same segmentation boundaries as truth
    :return: DataFrame with columns zi, truth_label, multiples, inferred_label; one row per frame
    """
    assert truth.shape == inferred.shape
    assert np.all((truth.data == 0) == (inferred.data == 0))
    assert inferred.chroma == 1

    labels = []
    for zi in range(0, inferred.nz):
        inf_img = inferred.data[:, :, :, zi]

        unique = np.unique(inf_img)
        assert unique[0] == 0

        # element 0 is background; an empty frame has nothing else, so report background rather than IndexError
        truth_unique = np.unique(truth.data[:, :, :, zi])
        truth_label = truth_unique[1] if len(truth_unique) > 1 else truth_unique[0]
        dd = {'zi': zi, 'truth_label': truth_label, 'multiples': False}

        if len(unique) == 1:  # no object in frame
            dd['inferred_label'] = unique[0]
        elif len(unique) > 2:  # multiple object classes in frame, so keep only the largest connected object
            ob_id = skimage.measure.label(inf_img)
            pr = skimage.measure.regionprops_table(ob_id, properties=['label', 'area'])
            # select pixels by connected-component id (ob_id), not by class value (inf_img)
            largest = ob_id == pr['label'][pr['area'].argmax()]
            dd['inferred_label'] = np.unique(largest * inf_img)[-1]
            dd['multiples'] = True
        else:  # exactly one unique object class in frame
            dd['inferred_label'] = unique[1]
        labels.append(dd)
    return pd.DataFrame(labels)

if __name__ == '__main__':
    root = Path('c:/Users/rhodes/projects/proj0011-plankton-seg/')
    template_ilp = root / 'exp0014/template_obj.ilp'
    where_patch_stack = root / 'exp0009/output/labeled_patches-20231018-0006'

    # auto-populate an object classifier
    auto_ilp = generate_ilastik_object_classifier(
        template_ilp,
        where_patch_stack,
        stack_name='train',
        proj_name='auto_obj_before'
    )

    def infer_and_compare(ilp, suffix):
        """
        Classify the training patch stack with project file `ilp`, write the result stack and a truth-vs-inferred
        comparison table into where_patch_stack (file names tagged with `suffix`), and print a summary plus selected
        project-file values.
        """
        # infer object labels from the same data used to train the classifier
        train_zstack_raw = generate_file_accessor(where_patch_stack / 'zstack_train_raw.tif')
        train_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_train_mask.tif')
        mod = PatchStackObjectClassifier({'project_file': where_patch_stack / ilp})
        result_acc, _ = mod.infer(train_zstack_raw, train_zstack_mask)
        write_accessor_data_to_file(where_patch_stack / f'zstack_train_result_{suffix}.tif', result_acc)

        # write comparison tables
        train_truth_labels = generate_file_accessor(where_patch_stack / 'zstack_train_label.tif')
        df_comp = compare_object_maps(train_truth_labels, result_acc)
        df_comp.to_csv(where_patch_stack / f'compare_train_result_{suffix}.csv', index=False)
        print(f'Generated ilastik project {ilp}')
        print('Truth and inferred labels match?')
        # Series.value_counts instead of top-level pd.value_counts, which is deprecated since pandas 2.1
        print((df_comp['truth_label'] == df_comp['inferred_label']).value_counts())

        # report out some things for debugging
        rks = [
            'ObjectClassification/MaxNumObj',
            'ObjectClassification/LabelInputs/0000/0',
            'currentApplet'
        ]
        with h5py.File(ilp, 'r') as h5:
            for r in rks:
                print(f'{r}: {h5[r][()]}')

    # infer object labels from the same data used to train the classifier
    infer_and_compare(auto_ilp, 'before')

    # copy project and prompt user input once ilastik file has been modified in-app
    mod_ilp = shutil.copy(auto_ilp, where_patch_stack / 'auto_obj_after.ilp')
    print(f'Press enter when project file {mod_ilp} has been updated in ilastik')
    input()

    # repeat inference with the modified project file
    infer_and_compare(mod_ilp, 'after')