Skip to content
Snippets Groups Projects
models.py 7.22 KiB
Newer Older
import extensions.ilastik.conf
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel

    def __init__(self, params, autoload=True):
        """
        Resolve the ilastik project file named in params and open it in a headless ilastik shell.

        :param params: dict-like that must contain 'project_file', given either as an absolute path or as a path
            relative to extensions.ilastik.conf.paths['project_files']; the entry is rewritten in place as a string
        :param autoload: accepted for interface compatibility; not read by this method
        :raises ParameterExpectedError: if params carries no 'project_file' entry, or if the workflow of the opened
            project is not an instance of the class returned by self.get_workflow()
        :raises FileNotFoundError: if the resolved project file does not exist
        """
        # Validate before subscripting, so that a missing entry is reported as a parameter error, not a KeyError
        if not params or 'project_file' not in params:
            raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')

        self.project_file = Path(params['project_file'])
        params['project_file'] = str(self.project_file)
        if self.project_file.is_absolute():
            pap = self.project_file
        else:
            pap = extensions.ilastik.conf.paths['project_files'] / self.project_file
        self.project_file_abspath = pap
        if not pap.exists():
            raise FileNotFoundError(f'Project file does not exist: {pap}')

        # ilastik is imported here rather than at module level, so importing this module does not require it
        from ilastik import app
        from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo

        # Kept on the instance so that infer() can wrap in-memory arrays without importing ilastik again
        self.PreloadedArrayDatasetInfo = PreloadedArrayDatasetInfo

        os.environ["LAZYFLOW_THREADS"] = "8"
        os.environ["LAZYFLOW_TOTAL_RAM_MB"] = "24000"

        args = app.parse_args([])
        args.headless = True
        # Hand the resolved project to ilastik; without it the shell has no project to build a workflow from
        args.project = str(self.project_file_abspath)
        shell = app.main(args, init_logging=False)

        expected_workflow = self.get_workflow()
        if not isinstance(shell.workflow, expected_workflow):
            raise ParameterExpectedError(
                f'Ilastik project file {self.project_file} does not describe an instance of {expected_workflow}'
            )
        self.shell = shell

# ilastik model with model_id 'ilastik_pixel_classification'; its bases are IlastikModel and SemanticSegmentationModel.
# NOTE(review): several statements in this class look incomplete as written (see the notes below) -- TODO restore
# the missing lines before relying on it.
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
    model_id = 'ilastik_pixel_classification'
        # NOTE(review): this import is indented like a function body, but no enclosing def precedes it -- presumably
        # a get_workflow() accessor belongs here; TODO confirm.
        from ilastik.workflows import PixelClassificationWorkflow
    # Runs the ilastik batch export and returns (InMemoryDataAccessor, {'success': True}).
    def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
        # View the input array with the axis tags 'yxcz'
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
                # NOTE(review): dangling dict entry -- the assignment of dsi, which run_export reads below, is absent.
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
        # Exactly one exported map is expected
        assert len(pxmaps) == 1, 'ilastik generated more than one pixel map'
        # NOTE(review): yxcz is never assigned in this method -- TODO confirm how it is derived from pxmaps.
        return InMemoryDataAccessor(data=yxcz), {'success': True}
    # Thresholds the px_class channel (third axis) of a pixel map at px_prob_threshold.
    def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
        # NOTE(review): pxmap is never assigned here (presumably the first element of self.infer(img) -- confirm),
        # and mask is computed but not returned.
        mask = pxmap.data[:, :, px_class, :] > px_prob_threshold

class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
    """
    Ilastik object classifier that is given a raw image together with a mask segmenting it.
    """
    model_id = 'ilastik_object_classification_from_segmentation'

    @staticmethod
    def get_workflow():
        """
        :return: the ilastik ObjectClassificationWorkflowBinary class
        """
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
        return ObjectClassificationWorkflowBinary

    def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
        """
        Run the ilastik batch export on an image and its segmentation.

        :param input_img: accessor whose .data is passed to ilastik as 'Raw Data', tagged with the axes 'yxcz'
        :param segmentation_img: accessor for which is_mask() is true; its .data is passed as 'Segmentation Image'
        :return: tuple of an InMemoryDataAccessor wrapping the exported object map and the dict {'success': True}
        """
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        assert segmentation_img.is_mask()
        if segmentation_img.dtype == 'bool':
            # Boolean masks are rescaled to uint8 with True mapped to 255
            seg = 255 * segmentation_img.data.astype('uint8')
        else:
            seg = segmentation_img.data
        tagged_seg_data = vigra.taggedView(seg, 'yxcz')

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]

        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]

        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        # Move the leading axis of the export to the end, so that the result is laid out like the 'yxcz' inputs
        yxcz = np.moveaxis(
            obmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
        """
        Return the object map that infer() produces for an image and its mask.

        :param img: input image
        :param mask: segmentation mask for img
        :param kwargs: forwarded to the parent class's label_instance_class
        :return: first element of the tuple returned by infer(img, mask)
        """
        # The parent implementation is called first; its return value is not used
        super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
        obmap, _ = self.infer(img, mask)
        return obmap

class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
    model_id = 'ilastik_object_classification_from_pixel_predictions'

    @staticmethod
    def get_workflow():
        """
        :return: the ilastik ObjectClassificationWorkflowPrediction class, imported on demand
        """
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import (
            ObjectClassificationWorkflowPrediction as workflow_class,
        )
        return workflow_class
    def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
        """
        Run the ilastik batch export on an image and its prediction maps.

        :param input_img: accessor whose .data is passed to ilastik as 'Raw Data', tagged with the axes 'yxcz'
        :param pxmap_img: accessor whose .data is passed to ilastik as 'Prediction Maps', tagged with the axes 'yxcz'
        :return: tuple of an InMemoryDataAccessor wrapping the exported object map and the dict {'success': True}
        """
        # Tag both arrays first, raw image before prediction maps
        raw_view, prediction_view = (
            vigra.taggedView(accessor.data, 'yxcz') for accessor in (input_img, pxmap_img)
        )

        as_dataset = self.PreloadedArrayDatasetInfo
        lanes = [{
            'Raw Data': as_dataset(preloaded_array=raw_view),
            'Prediction Maps': as_dataset(preloaded_array=prediction_view),
        }]

        # [z x h x w x n]
        exported = self.shell.workflow.batchProcessingApplet.run_export(lanes, export_to_array=True)
        assert len(exported) == 1, 'ilastik generated more than one object map'

        # Axes 1, 2, 3 move to the front and axis 0 follows them
        reordered = np.moveaxis(exported[0], [1, 2, 3, 0], [0, 1, 2, 3])
        status = {'success': True}
        return InMemoryDataAccessor(data=reordered), status

    def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
        """
        Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
        assigned a class.
        :param img: input image
        :param pxmap: map of pixel probabilities
        :param kwargs:
            pixel_classification_channel: channel of pxmap used to segment objects (default 0)
            pixel_classification_threshold: threshold of pxmap used to segment objects (default 0.5)
        :return:
        :raises InvalidInputImageError: if img and pxmap differ in shape
        """
        if not img.shape == pxmap.shape:
            raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
        # TODO: check that pxmap is in-range
        pxch = kwargs.get('pixel_classification_channel', 0)
        # kwargs is a plain dict, so the threshold has to be looked up with .get(); calling kwargs raises TypeError
        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
        # Binary mask of the selected channel: True where the probability exceeds the threshold
        mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
        # super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)