Skip to content
Snippets Groups Projects
transfer_labels_to_ilastik_object_classifier.py 7.58 KiB
import shutil
from pathlib import Path
import h5py
import json
import numpy as np
import pandas as pd
import skimage
import uuid
import vigra

from extensions.chaeo.util import autonumber_new_file
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file

class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """Ilastik object classifier applied to a stack of patches.

    Each Z-slice of the input accessors is handed to ilastik as one time point (the
    accessor's Z axis is mapped onto ilastik's T axis), so every slice is classified
    as an independent 2D image.
    """

    @staticmethod
    def make_tczyx(acc):
        """Rearrange a single-channel YXCZ accessor into a TCZYX array for ilastik.

        The accessor's Z axis becomes T and singleton C and Z axes are inserted,
        giving an array of shape (nz, 1, 1, h, w).
        """
        assert acc.chroma == 1
        tyx = np.moveaxis(
            acc.data[:, :, 0, :],  # drop the single channel: YX(C)Z -> YXZ
            [2, 0, 1],
            [0, 1, 2]
        )  # YXZ -> ZYX, where Z serves as ilastik's T axis
        return np.expand_dims(tyx, (1, 2))  # insert singleton C and Z axes -> TCZYX

    def infer(
            self,
            input_img: GenericImageDataAccessor,
            segmentation_img: GenericImageDataAccessor,
    ) -> (InMemoryDataAccessor, dict):
        """Classify the objects in every Z-slice of a patch stack.

        :param input_img: single-channel raw image stack
        :param segmentation_img: mask stack with the same shape as input_img
        :return: (object map accessor with the same shape as input_img, {'success': True})
        """
        assert segmentation_img.is_mask()
        assert input_img.chroma == 1
        # fail early with a clear message rather than somewhere inside ilastik's export
        assert segmentation_img.shape == input_img.shape, 'raw data and segmentation shapes differ'

        tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
        tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        # the exported map comes back as Z(1)YX(1), i.e. (nz, 1, h, w, 1)
        assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1)
        yxcz = np.moveaxis(
            obmaps[0][:, :, :, :, 0],  # drop the trailing singleton axis -> (nz, 1, h, w)
            [2, 3, 1, 0],
            [0, 1, 2, 3]
        )  # -> (h, w, 1, nz), the accessor's YXCZ layout

        assert yxcz.shape == input_img.shape
        return InMemoryDataAccessor(data=yxcz), {'success': True}

def get_dataset_info(h5, lane=0):
    """Read the input-dataset metadata of one lane from an ilastik project file.

    :param h5: open ilastik project (h5py File/Group), indexed by '/'-separated paths
    :param lane: index of the input lane to inspect
    :return: dict with one entry per input role ('Raw Data', 'Segmentation Image'), each
        holding whichever of 'location', 'filePath', 'shape', 'nickname' are present, the
        dataset 'id' (a uuid.UUID, or '<invalid UUID>' if unparsable), the parsed 'axistags'
        JSON and its axis keys under 'axes'; plus a 'misc' entry with the number of
        object-label inputs stored for the lane.
    :raises KeyError: if a lane's datasetId or axistags, or the lane's label-input group, is missing
    """
    lns = f'{lane:04d}'
    lane_path = f'Input Data/infos/lane{lns}'
    info = {}
    for gk in ['Raw Data', 'Segmentation Image']:
        grp = f'{lane_path}/{gk}'
        info[gk] = {}
        for dk in ['location', 'filePath', 'shape', 'nickname']:
            # best-effort: report a missing field and keep going, but let other errors surface
            try:
                info[gk][dk] = h5[f'{grp}/{dk}'][()]
            except KeyError as e:
                print(f'{grp}/{dk} not found in project file: {e}')
        try:
            info[gk]['id'] = uuid.UUID(h5[f'{grp}/datasetId'][()].decode())
        except ValueError:
            info[gk]['id'] = '<invalid UUID>'
        info[gk]['axistags'] = json.loads(h5[f'{grp}/axistags'][()].decode())
        info[gk]['axes'] = [ax['key'] for ax in info[gk]['axistags']['axes']]

    info['misc'] = {
        'number_of_label_inputs': len(h5[f'ObjectClassification/LabelInputs/{lns}'])
    }
    return info


def generate_ilastik_object_classifier(template_ilp, where: str, lane=0):
    """Create a new ilastik object-classification project pre-labeled from a patch stack.

    ``where`` must contain zstack_train_raw.tif, zstack_train_mask.tif, train_stack.csv
    (columns 'zi' and 'annotation_class_id', one row per Z-slice) and labels_key.csv
    (columns 'annotation_class_id' and 'annotation_class'). The template project is copied
    into ``where`` under a name from autonumber_new_file(root, 'auto-obj', 'ilp'), then the
    copy's input file paths/shapes, label names and per-slice object labels are rewritten.

    :param template_ilp: path to the template ilastik project (.ilp)
    :param where: directory containing the patch stack and CSV files
    :param lane: index of the project input lane to validate and rewrite
    :return: path of the new project file, as returned by shutil.copy
    """
    root = Path(where)
    paths = {
        'Raw Data': root / 'zstack_train_raw.tif',
        'Segmentation Image': root / 'zstack_train_mask.tif',
    }

    # validate z-stack input data
    accessors = {k: generate_file_accessor(pa) for k, pa in paths.items()}

    assert accessors['Raw Data'].chroma == 1
    assert accessors['Segmentation Image'].is_mask()
    assert len({a.hw for a in accessors.values()}) == 1  # same height and width
    assert len({a.nz for a in accessors.values()}) == 1  # same z-depth
    nz = accessors['Raw Data'].nz

    # load per-slice labels; every slice index 0..nz-1 must appear exactly once
    csv_path = root / 'train_stack.csv'
    assert csv_path.exists()
    df_patches = pd.read_csv(csv_path)
    # array_equal also fails cleanly (instead of raising) when the row count differs from nz
    assert np.array_equal(df_patches['zi'].sort_values().to_numpy(), np.arange(0, nz))
    # since zi is exactly 0..nz-1, ordering by zi makes position == slice index
    labels_by_zi = df_patches.sort_values('zi')['annotation_class_id'].to_numpy()

    df_labels = pd.read_csv(root / 'labels_key.csv')
    label_names = list(df_labels.sort_values('annotation_class_id').annotation_class.unique())
    assert len(label_names) >= 2
    label_names[0] = 'none'  # the class with the lowest id is renamed 'none'

    # open and validate the template project file, checking the lane that will be rewritten
    with h5py.File(template_ilp, 'r') as h5:
        info = get_dataset_info(h5, lane=lane)

        for hg in ['Raw Data', 'Segmentation Image']:
            assert info[hg]['location'] == b'FileSystem'
            assert info[hg]['axes'] == ['t', 'y', 'x']

    new_ilp = shutil.copy(template_ilp, root / autonumber_new_file(root, 'auto-obj', 'ilp'))

    # write to the new project file
    lns = f'{lane:04d}'
    with h5py.File(new_ilp, 'r+') as h5:
        def set_ds(grp, ds, val):
            # overwrite an existing dataset of this lane's input info in place
            h5[f'Input Data/infos/lane{lns}/{grp}/{ds}'][()] = val

        for hg in ['Raw Data', 'Segmentation Image']:
            set_ds(hg, 'filePath', str(paths[hg]))
            set_ds(hg, 'nickname', paths[hg].stem)
            shape_zyx = [accessors[hg].shape_dict[ax] for ax in ['Z', 'Y', 'X']]
            set_ds(hg, 'shape', np.array(shape_zyx))

        # replace the label names
        del h5['ObjectClassification/LabelNames']
        ln = np.array(label_names)
        h5.create_dataset('ObjectClassification/LabelNames', data=ln.astype('O'))

        # replace the object labels with one entry per Z-slice, keyed by slice index;
        # each entry is [0., class id] (meaning of the leading 0 follows the template format)
        la_groupname = f'ObjectClassification/LabelInputs/{lns}'
        del h5[la_groupname]
        lag = h5.create_group(la_groupname)
        for zi in range(0, nz):
            lag[f'{zi}'] = np.array([0., float(labels_by_zi[zi])])

    return new_ilp

def compare_object_maps(truth: GenericImageDataAccessor, inferred: GenericImageDataAccessor) -> pd.DataFrame:
    """Compare per-slice object class labels of a ground-truth and an inferred object map.

    Both maps must have the same shape, be single-channel, and have background (0) in
    exactly the same pixels. For each Z-slice:

    - truth_label is the smallest nonzero value in the truth slice (0 if the slice is empty);
    - inferred_label is the single nonzero class in the inferred slice (0 if empty); when
      the inferred slice holds more than one class, the class of its largest connected
      region is used and 'multiples' is set to True.

    :return: DataFrame with columns zi, truth_label, multiples, inferred_label; one row per slice
    """
    assert truth.shape == inferred.shape
    assert np.all((truth.data == 0) == (inferred.data == 0))
    assert inferred.chroma == 1

    labels = []
    for zi in range(0, inferred.nz):
        inf_img = inferred.data[:, :, 0, zi]
        truth_classes = np.unique(truth.data[:, :, 0, zi])
        inf_classes = np.unique(inf_img)
        assert inf_classes[0] == 0

        dd = {
            'zi': zi,
            # an empty slice has only background; report 0 instead of indexing past the end
            'truth_label': truth_classes[1] if len(truth_classes) > 1 else truth_classes[0],
            'multiples': False,
        }

        if len(inf_classes) == 1:  # no object in slice
            dd['inferred_label'] = inf_classes[0]
        elif len(inf_classes) > 2:  # more than one object class: keep the largest region's class
            # label() joins neighboring pixels of equal value, so each region has a single class
            ob_id = skimage.measure.label(inf_img)
            pr = skimage.measure.regionprops_table(ob_id, properties=['label', 'area'])
            largest = pr['label'][pr['area'].argmax()]
            # select pixels by region id (not by class value) and read that region's class
            dd['inferred_label'] = inf_img[ob_id == largest][0]
            dd['multiples'] = True
        else:  # exactly one object class in slice
            dd['inferred_label'] = inf_classes[1]
        labels.append(dd)
    return pd.DataFrame(labels)

if __name__ == '__main__':
    # Build an object-classifier project from a labeled patch stack, run it back over the
    # training stack, then save the inferred object map and a per-slice comparison table.
    proj_dir = Path('c:/Users/rhodes/projects/proj0011-plankton-seg/')
    template = proj_dir / 'exp0014/template_obj.ilp'
    stack_dir = proj_dir / 'exp0009/output/labeled_patches-20231016-0002'

    ilp_path = generate_ilastik_object_classifier(template, stack_dir)

    raw_acc, mask_acc = (
        generate_file_accessor(stack_dir / fname)
        for fname in ('zstack_train_raw.tif', 'zstack_train_mask.tif')
    )

    classifier = PatchStackObjectClassifier({'project_file': ilp_path})
    inferred_acc = classifier.infer(raw_acc, mask_acc)[0]

    result_path = stack_dir / 'result.tif'
    write_accessor_data_to_file(result_path, inferred_acc)
    print(result_path)

    # per-slice comparison of inferred vs. annotated object classes
    labels_acc = generate_file_accessor(stack_dir / 'zstack_train_label.tif')
    comparison = compare_object_maps(labels_acc, inferred_acc)
    comp_path = stack_dir / autonumber_new_file(stack_dir, 'comp', 'csv')
    comparison.to_csv(comp_path, index=False)