Skip to content
Snippets Groups Projects
models.py 5.05 KiB
from pathlib import Path
import shutil

import h5py
import numpy as np
import vigra

from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel


class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
    """
    Wrap ilastik object classification for inputs comprising raw image and binary segmentation masks, both represented
    as time-series images where each frame contains only one object.

    Frames of the patch stack are addressed through the accessor's Z axis (``nz``); the output map has one Z slice per
    input frame.
    """

    def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> 'tuple[MonoPatchStack, dict]':
        """
        Classify the single object in each frame of a patch stack.

        :param input_acc: single-channel stack of raw-data patches
        :param segmentation_acc: binary mask stack corresponding frame-for-frame to input_acc
        :return: tuple of (MonoPatchStack holding ilastik's exported object map arranged as (h, w, nz), dict with key
            'success')
        :raises AssertionError: if segmentation_acc is not a mask, input_acc is not single-channel, or ilastik's export
            does not contain exactly one map of the expected shape
        """
        assert segmentation_acc.is_mask()
        assert input_acc.chroma == 1

        tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
        tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')

        # one dataset entry pairing the raw data with its segmentation
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)

        assert len(obmaps) == 1, f'expected exactly one object map from ilastik, got {len(obmaps)}'

        # for some reason ilastik scrambles these axes to Z(1)YX(1), i.e. the map arrives as (nz, 1, h, w, 1)
        nz = input_acc.nz
        h, w = input_acc.hw[0], input_acc.hw[1]
        assert obmaps[0].shape == (nz, 1, h, w, 1)

        # drop the two singleton axes, then move Z last: (nz, h, w) -> (h, w, nz)
        yxz = np.moveaxis(
            obmaps[0][:, 0, :, :, 0],
            [1, 2, 0],
            [0, 1, 2]
        )

        assert yxz.shape[0:2] == input_acc.hw
        assert yxz.shape[2] == input_acc.nz
        return MonoPatchStack(data=yxz), {'success': True}


def _labels_per_frame(label_stack, n_label_names: int) -> list:
    """
    Extract one object label ID per frame of a label stack, validating each frame.

    Each frame must contain background (0) plus exactly one non-zero label ID, and that ID must not exceed the number
    of label names.

    :param label_stack: stack of patches containing object labels (must provide ``count`` and ``iat``)
    :param n_label_names: number of available label names
    :return: list with the non-zero label ID of each frame, in frame order
    :raises AssertionError: if any frame violates the constraints above
    """
    labels = []
    for ii in range(label_stack.count):
        unique = np.unique(label_stack.iat(ii))
        assert len(unique) >= 2, f'Label image frame {ii} contains no non-zero label value'
        assert len(unique) <= 2, f'Label image frame {ii} contains more than one non-zero value'
        assert unique[0] == 0, 'Label image does not contain unlabeled background'
        assert unique[-1] < n_label_names + 1, f'Label ID {unique[-1]} exceeds number of label names: {n_label_names}'
        labels.append(unique[-1])
    return labels


def _delete_if_present(h5, key: str) -> None:
    """Delete the dataset or group at path ``key`` from an open HDF5 file/group if it exists."""
    if key in h5:
        del h5[key]


def generate_ilastik_object_classifier(
        template_ilp: Path,
        target_ilp: Path,
        raw_stack: MonoPatchStackFromFile,
        mask_stack: MonoPatchStackFromFile,
        label_stack: MonoPatchStackFromFile,
        label_names: list,
        lane: int = 0,
) -> Path:
    """
    Starting with a template project file, transfer input data and labels to a new project file.

    Input file paths are written into the project as bare file names (relative paths), not absolute paths.
    NOTE(review): presumably ilastik resolves these relative to the project file, so the raw/mask files are expected
    alongside target_ilp — confirm.

    :param template_ilp: path to existing ilastik object classifier to use as a template
    :param target_ilp: path to new classifier
    :param raw_stack: stack of patches containing raw data
    :param mask_stack: stack of patches containing object masks
    :param label_stack: stack of patches containing object labels; each frame must hold background (0) and exactly
        one label ID in 1..len(label_names)
    :param label_names: list of label names
    :param lane: ilastik lane identifier
    :return: path to generated object classifier
    :raises AssertionError: if stack shapes differ, label frames are invalid, or the template lacks expected fields
    """
    assert mask_stack.shape == raw_stack.shape
    assert label_stack.shape == raw_stack.shape

    # validate labels before copying, so bad label data does not leave a half-configured project at target_ilp
    labels = _labels_per_frame(label_stack, len(label_names))

    new_ilp = shutil.copy(template_ilp, target_ilp)

    accessors = {
        'Raw Data': raw_stack,
        'Segmentation Image': mask_stack,
    }

    # write to new project file
    with h5py.File(new_ilp, 'r+') as h5:

        for gk, acc in accessors.items():
            group = f'Input Data/infos/lane{lane:04d}/{gk}'

            # set path to input image files; these fields must already exist in the template, so delete unguarded
            del h5[f'{group}/filePath']
            h5[f'{group}/filePath'] = acc.fpath.name
            assert not Path(h5[f'{group}/filePath'][()].decode()).is_absolute()
            assert h5[f'{group}/filePath'][()] == acc.fpath.name.encode()
            assert h5[f'{group}/location'][()] == 'FileSystem'.encode()

            # set input nickname
            del h5[f'{group}/nickname']
            h5[f'{group}/nickname'] = acc.fpath.stem

            # set input shape
            del h5[f'{group}/shape']
            shape_zyx = [acc.shape_dict[ax] for ax in ['Z', 'Y', 'X']]
            h5[f'{group}/shape'] = np.array(shape_zyx)

        # replace label names; object dtype lets h5py store them as variable-length strings
        k = 'ObjectClassification/LabelNames'
        _delete_if_present(h5, k)
        h5.create_dataset(k, data=np.array(label_names).astype('O'))

        # NOTE(review): value kept as len(label_names) - 1 from the original; its ilastik semantics are not visible here
        k = 'ObjectClassification/MaxNumObj'
        _delete_if_present(h5, k)
        h5[k] = len(label_names) - 1

        # NOTE(review): index of the applet the project opens on — meaning of 1 not visible here, confirm
        _delete_if_present(h5, 'currentApplet')
        h5['currentApplet'] = 1

        # change object labels: one [0., label] entry per frame of the label stack
        k = f'ObjectClassification/LabelInputs/{lane:04d}'
        _delete_if_present(h5, k)
        lag = h5.create_group(k)
        for zi, la in enumerate(labels):
            lag[f'{zi}'] = np.array([0., float(la)])

        # delete existing region features and classification weights inherited from the template
        _delete_if_present(h5, f'ObjectExtraction/RegionFeatures/{lane:04d}')
        _delete_if_present(h5, 'ObjectClassification/ClassifierForests')

    return Path(new_ilp)