import json
import os
from pathlib import Path

import numpy as np
import vigra

import model_server.extensions.ilastik.conf
from model_server.base.accessors import PatchStack
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel


class IlastikModel(Model):

    def __init__(self, params, autoload=True, enforce_embedded=True):
        """
        Base class for models that run via ilastik shell API
        :param params:
            project_file: path to ilastik project file
        :param autoload: automatically load model into memory if true
        :param enforce_embedded:
            raise an error if any input data are referenced on the filesystem rather than embedded in the project file
        """
        if 'project_file' not in params:
            raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
        self.project_file = Path(params['project_file'])
        self.enforce_embedded = enforce_embedded
        params['project_file'] = str(self.project_file)
        if self.project_file.is_absolute():
            pap = self.project_file
        else:
            pap = model_server.extensions.ilastik.conf.paths['project_files'] / self.project_file
        self.project_file_abspath = pap
        if not pap.exists():
            raise FileNotFoundError(f'Project file does not exist: {pap}')

        self.shell = None
        super().__init__(autoload, params)

    def load(self):
        from ilastik import app
        from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo

        self.PreloadedArrayDatasetInfo = PreloadedArrayDatasetInfo

        os.environ["LAZYFLOW_THREADS"] = "8"
        os.environ["LAZYFLOW_TOTAL_RAM_MB"] = "24000"

        args = app.parse_args([])
        args.headless = True
        args.project = str(self.project_file_abspath)
        shell = app.main(args, init_logging=False)

        # validate that inputs are embedded in the project file
        h5 = shell.projectManager.currentProjectFile
        for lane in h5['Input Data/infos'].keys():
            for role in h5[f'Input Data/infos/{lane}'].keys():
                grp = h5[f'Input Data/infos/{lane}/{role}']
                if self.enforce_embedded and ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal':
                    raise IlastikInputEmbedding('Cannot load ilastik project file whose inputs are on the filesystem')
        if not isinstance(shell.workflow, self.get_workflow()):
            raise ParameterExpectedError(
                f'Ilastik project file {self.project_file} describes a {shell.workflow.__class__.__name__}, '
                f'not the expected {self.get_workflow().__name__}'
            )
        self.shell = shell

        return True

    @property
    def model_shape_dict(self):
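        """
        Shape of the project's raw input data as a dict keyed by axis, e.g.
        {'Y': 512, 'X': 512, 'C': 2, 'Z': 1, 'T': 1} (illustrative values); T, C, and Z default to 1 if absent.
        """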
        raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
        json_tags = raw_info['axistags'][()]
        ax_keys = [ax['key'].upper() for ax in json.loads(json_tags)['axes']]
        shape = raw_info['shape'][()]
        dd = dict(zip(ax_keys, shape))
        for ci in 'TCZ':
            if ci not in dd.keys():
                dd[ci] = 1
        return dd

    @property
    def model_chroma(self):
        return self.model_shape_dict['C']

    @property
    def model_3d(self):
        return self.model_shape_dict['Z'] > 1


class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
    model_id = 'ilastik_pixel_classification'
    operations = ['segment', ]

    @staticmethod
    def get_workflow():
        from ilastik.workflows import PixelClassificationWorkflow
        return PixelClassificationWorkflow

    @property
    def labels(self):
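        """
        Names of the label classes stored in the project file
        """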
        h5 = self.shell.projectManager.currentProjectFile
        return [l.decode() for l in h5['PixelClassification/LabelNames'][()]]

    def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
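        """
        Run the pixel classification workflow on a single image; return a YXCZ-ordered probability
        map accessor along with a status dict.
        """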
        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
            raise IlastikInputShapeError()

        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
            }
        ]
        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(pxmaps) == 1, 'ilastik generated more than one pixel map'

        yxcz = np.moveaxis(
            pxmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict):
        """
        Iterate over a patch stack, calling inference separately on each cropped patch
        """
        nc = len(self.labels)
        data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float)  # interpret as PYXCZ
        for i in range(0, img.count):
            sl = img.get_slice_at(i)
            data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data
        return PatchStack(data), {'success': True}

    def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
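        """
        Return a binary mask of pixels whose probability of belonging to class px_class exceeds
        px_prob_threshold.
        """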
        pxmap, _ = self.infer(img)
        mask = pxmap.data[:, :, px_class, :] > px_prob_threshold
        return InMemoryDataAccessor(mask)
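
# Example usage (illustrative sketch; 'pixel_classifier.ilp' and the input array are placeholders):
#   model = IlastikPixelClassifierModel(params={'project_file': 'pixel_classifier.ilp'})
#   img = InMemoryDataAccessor(data=yxcz_array)  # YXCZ-ordered numpy array
#   mask = model.label_pixel_class(img, px_class=0, px_prob_threshold=0.5)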


class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
    model_id = 'ilastik_object_classification_from_segmentation'

    @staticmethod
    def _make_8bit_mask(nda):
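        """
        Convert a boolean mask to an 8-bit (0/255) mask; other dtypes are passed through unchanged.
        """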
        if nda.dtype == 'bool':
            return 255 * nda.astype('uint8')
        else:
            return nda

    @staticmethod
    def get_workflow():
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
        return ObjectClassificationWorkflowBinary

    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
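        """
        Run the object classification (from binary segmentation) workflow on an image and its
        segmentation mask; return an object map accessor of matching type along with a status dict.
        """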
        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
            raise IlastikInputShapeError()

        assert segmentation_img.is_mask()
        if isinstance(input_img, PatchStack):
            assert isinstance(segmentation_img, PatchStack)
            tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx')
            tagged_seg_data = vigra.taggedView(
                self._make_8bit_mask(segmentation_img.pczyx),
                'tczyx'
            )
        else:
            tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
            tagged_seg_data = vigra.taggedView(
                self._make_8bit_mask(segmentation_img.data),
                'yxcz'
            )

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        if isinstance(input_img, PatchStack):
            pyxcz = np.moveaxis(
                obmaps[0],
                [0, 1, 2, 3, 4],
                [0, 4, 1, 2, 3]
            )
            return PatchStack(data=pyxcz), {'success': True}
        else:
            yxcz = np.moveaxis(
                obmaps[0],
                [1, 2, 3, 0],
                [0, 1, 2, 3]
            )
            return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
        obmap, _ = self.infer(img, mask)
        return obmap
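
# Example usage (illustrative sketch; the project file and accessors are placeholders):
#   model = IlastikObjectClassifierFromSegmentationModel(params={'project_file': 'object_classifier.ilp'})
#   obmap = model.label_instance_class(img, binary_mask)  # binary_mask: a boolean or 8-bit mask accessor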


class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
    model_id = 'ilastik_object_classification_from_pixel_predictions'

    @staticmethod
    def get_workflow():
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
        return ObjectClassificationWorkflowPrediction

    def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
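        """
        Run the object classification (from pixel predictions) workflow on an image and its pixel
        probability map; return a YXCZ-ordered object map accessor along with a status dict.
        """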
        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
            raise IlastikInputShapeError()

        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        yxcz = np.moveaxis(
            obmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
        """
        Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
        assigned a class.
        :param img: input image
        :param pxmap: map of pixel probabilities
        :param kwargs:
            pixel_classification_channel: channel of pxmap used to segment objects
            pixel_classification_threshold: threshold of pxmap used to segment objects
        :return:
        """
        if not img.shape == pxmap.shape:
            raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
        pxch = kwargs.get('pixel_classification_channel', 0)
        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
        mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
        obmap, _ = self.infer(img, mask)
        return obmap

    def make_instance_segmentation_model(self, px_ch: int):
        """
        Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a
        second input.
        :param px_ch: channel of pixel probability map to use
        :return:
            InstanceSegmentationModel object
        """
        class _Mod(self.__class__, InstanceSegmentationModel):
            def label_instance_class(
                    self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
            ) -> GenericImageDataAccessor:
                if mask.dtype == 'bool':
                    norm_mask = 1.0 * mask.data
                else:
                    norm_mask = mask.data / np.iinfo(mask.dtype).max
                norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32'))
                return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch)
        return _Mod(params={'project_file': self.project_file})
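
# Example usage (illustrative sketch; file and accessors are placeholders): derive a model that
# accepts binary masks instead of pixel probability maps as its second input:
#   model = IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': 'obj_from_pxmap.ilp'})
#   seg_model = model.make_instance_segmentation_model(px_ch=0)
#   obmap = seg_model.label_instance_class(img, binary_mask)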



class Error(Exception):
    pass

class IlastikInputEmbedding(Error):
    """Raised when an ilastik project file references input data on the filesystem rather than embedding it"""
    pass

class IlastikInputShapeError(Error):
    """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape"""
    pass