Skip to content
Snippets Groups Projects
accessors.py 6.41 KiB
from abc import ABC, abstractmethod
import os
from pathlib import Path

import numpy as np
from skimage.io import imread

import czifile
import tifffile

from model_server.process import is_mask

class GenericImageDataAccessor(ABC):
    """
    Abstract base class that exposes an interface for image data, irrespective of whether it is instantiated
    from file I/O or other means.  Data are held as a 4d array with axes in the order Y, X, C, Z; concrete
    subclasses are responsible for populating self._data, e.g. via conform_data().
    """

    @abstractmethod
    def __init__(self):
        """
        Concrete subclasses must set self._data to a 4d array with axes in the order Y, X, C, Z.
        """
        pass

    @property
    def chroma(self):
        """Number of channels, i.e. the length of the C axis."""
        return self.shape_dict['C']

    @staticmethod
    def conform_data(data):
        """
        Pad an array of up to four dimensions with trailing singleton axes so that it becomes 4d.
        :param data: np.ndarray whose leading axes are already in the order Y, X, C, Z
        :return: data reshaped to exactly four dimensions
        :raises DataShapeError: if data has more than four dimensions or any axis of length zero
        """
        if len(data.shape) > 4 or (0 in data.shape):
            raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}')
        padding = (1,) * (4 - len(data.shape))
        return data.reshape(tuple(data.shape) + padding)

    def is_3d(self):
        """True if the image has more than one Z-slice."""
        return self.shape_dict['Z'] > 1

    def is_mask(self):
        """Whether the image data qualify as a mask, as judged by model_server.process.is_mask."""
        # resolves to the module-level function imported from model_server.process, not to this method
        return is_mask(self._data)

    def get_one_channel_data(self, channel: int):
        """
        Extract a single channel as a new in-memory accessor
        :param channel: index along the C axis
        :return: InMemoryDataAccessor whose data have shape (Y, X, 1, Z)
        """
        c = int(channel)
        # slice (rather than integer-index) so that the C axis is kept as a singleton
        return InMemoryDataAccessor(self.data[:, :, c:(c + 1), :])

    @property
    def pixel_scale_in_micrometers(self):
        """
        Physical pixel size per axis; the base implementation has no scale information.
        :return: dict mapping axis label to pixel size in micrometers (empty here)
        """
        return {}

    @property
    def dtype(self):
        """Data type of the underlying array."""
        return self.data.dtype

    @property
    def hw(self):
        """
        Get data height and width as a tuple
        :return: tuple of (Y, X) dimensions
        """
        return self.shape_dict['Y'], self.shape_dict['X']

    @property
    def nz(self):
        """Number of Z-slices, i.e. the length of the Z axis."""
        return self.shape_dict['Z']

    @property
    def data(self):
        """
        Return data as 4d with axes in order of Y, X, C, Z
        :return: np.ndarray
        """
        return self._data

    @property
    def shape(self):
        """Shape of the underlying 4d array, (Y, X, C, Z)."""
        return self._data.shape

    @property
    def shape_dict(self):
        """Mapping of axis label ('Y', 'X', 'C', 'Z') to that axis' length."""
        return dict(zip(('Y', 'X', 'C', 'Z'), self.data.shape))

class InMemoryDataAccessor(GenericImageDataAccessor):
    """
    Accessor wrapping an array that is already held in memory rather than read from a file.
    """
    def __init__(self, data):
        """
        :param data: np.ndarray with axes in the order Y, X[, C[, Z]]; padded to 4d with trailing singleton axes
        :raises DataShapeError: if data cannot be conformed to four dimensions
        """
        conformed = self.conform_data(data)
        self._data = conformed

class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file
    def __init__(self, fpath: Path):
        """
        Interface for image data that originates in an image file; subclasses call this first, then read the
        file and set self._data
        :param fpath: path to image file
        :raises FileAccessorError: if no file exists at fpath
        """
        if not os.path.exists(fpath):
            raise FileAccessorError(f'Could not find file at {fpath}')
        self.fpath = fpath

class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
    """
    Image stored in a TIFF file that contains exactly one series; its axes are rearranged into Y, X, C, Z
    order.  Any of Y, X, C, Z missing from the file is added as a singleton axis, and every other axis
    (e.g. T) must have length one.
    """
    def __init__(self, fpath: Path):
        super().__init__(fpath)

        try:
            tf = tifffile.TiffFile(fpath)
        except Exception as e:
            raise FileAccessorError(f'Unable to access data in {fpath}') from e
        # keep a handle so that __del__ can close the file
        self.tf = tf

        if len(tf.series) != 1:
            raise DataShapeError(f'Expect only one series in {fpath}')

        se = tf.series[0]
        yxcz = self._arrange_yxcz(se.asarray(), se.axes, fpath)
        self._data = self.conform_data(yxcz)

    @staticmethod
    def _arrange_yxcz(da, axes, fpath):
        """
        Reorder an array into Y, X, C, Z axes
        :param da: np.ndarray whose dimensions are labelled, in order, by the characters of axes
        :param axes: tifffile axes string, e.g. 'ZCYX'
        :param fpath: source file, used only in error messages
        :return: 4d np.ndarray with axes Y, X, C, Z
        :raises DataShapeError: if an axis other than Y, X, C, Z has length greater than one
        """
        order = ['Y', 'X', 'C', 'Z']
        # keep ALL axis labels so that indices into axs match the dimensions of da
        axs = list(axes)
        # tifffile labels RGB(A) color samples 'S'; treat them as channels when there is no explicit C axis
        if 'C' not in axs and 'S' in axs:
            axs[axs.index('S')] = 'C'
        # append any missing Y, X, C, Z axis to da as a trailing singleton dimension
        for k in order:
            if k not in axs:
                axs.append(k)
                da = np.expand_dims(da, da.ndim)
        yxcz = np.moveaxis(da, [axs.index(k) for k in order], [0, 1, 2, 3])
        # dimensions beyond the first four are the remaining axes; they can only be dropped if singleton
        if any(n != 1 for n in yxcz.shape[4:]):
            raise DataShapeError(f'Cannot handle axes other than Y, X, C, and Z with length > 1 in {fpath}: {axes}')
        return yxcz.reshape(yxcz.shape[0:4])

    def __del__(self):
        # __init__ may have failed before self.tf was assigned
        tf = getattr(self, 'tf', None)
        if tf is not None:
            tf.close()

class PngFileAccessor(GenericImageFileAccessor):
    """
    Image stored in a PNG file.  A 3d array returned by imread (color planes last) is mapped onto the C axis;
    a 2d array becomes a single channel.  Z is always a singleton axis.
    """
    def __init__(self, fpath: Path):
        super().__init__(fpath)

        try:
            arr = imread(fpath)
        except Exception as e:
            raise FileAccessorError(f'Unable to access data in {fpath}') from e

        # (Y, X) -> (Y, X, 1, 1) for mono, (Y, X, C) -> (Y, X, C, 1) for color; also validates the shape
        self._data = self.conform_data(arr)

class CziImageFileAccessor(GenericImageFileAccessor):
    """
    Image that is stored in a Zeiss .CZI file; may be multi-channel, and/or a z-stack,
    but not a time series or multiposition acquisition.  Any of Y, X, C, Z missing from the file is added as
    a singleton axis, and every other axis must have length one.
    """
    def __init__(self, fpath: Path):
        super().__init__(fpath)

        try:
            cf = czifile.CziFile(fpath)
        except Exception as e:
            raise FileAccessorError(f'Unable to access CZI data in {fpath}') from e
        # keep a handle so that __del__ can close the file and metadata can be read later
        self.czifile = cf

        # map each axis label in the file to its length
        sd = dict(zip(cf.axes, cf.shape))
        if sd.get('S', 1) > 1 or sd.get('T', 1) > 1:
            raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')

        order = ['Y', 'X', 'C', 'Z']
        # validate before reading pixel data: every axis besides Y, X, C, Z must be a singleton
        extra = {k: n for k, n in sd.items() if k not in order and n != 1}
        if extra:
            raise DataShapeError(f'Cannot handle image with non-singleton axes other than Y, X, C, and Z: {sd}')

        axs = list(cf.axes)
        da = cf.asarray()
        # append any missing Y, X, C, Z axis to da as a trailing singleton dimension
        for k in order:
            if k not in axs:
                axs.append(k)
                da = np.expand_dims(da, da.ndim)
        yxcz = np.moveaxis(da, [axs.index(k) for k in order], [0, 1, 2, 3])
        # remaining (singleton) axes trail the first four and are dropped here
        self._data = self.conform_data(yxcz.reshape(yxcz.shape[0:4]))

    def __del__(self):
        # __init__ may have failed before self.czifile was assigned
        cf = getattr(self, 'czifile', None)
        if cf is not None:
            cf.close()

    @property
    def pixel_scale_in_micrometers(self):
        """
        Physical pixel size per axis, read from the CZI scaling metadata
        :return: dict mapping axis label (only axes in shape_dict) to pixel size in micrometers; only entries
            whose DefaultUnitFormat is micrometers are included
        """
        scale_meta = self.czifile.metadata(raw=False)['ImageDocument']['Metadata']['Scaling']['Items']['Distance']
        # a single <Distance> element is parsed as one dict rather than a list of dicts
        if isinstance(scale_meta, dict):
            scale_meta = [scale_meta]
        sc = {}
        for m in scale_meta:
            if m.get('DefaultUnitFormat') == '\u00b5m' and m.get('Id') in self.shape_dict:  # literal mu-m
                # Value appears to be stored in meters, hence the conversion to micrometers
                sc[m['Id']] = m['Value'] * 1e6
        return sc


def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor, mkdir=True) -> bool:
    """
    Write an accessor's data to a TIFF file in ImageJ format
    :param fpath: path of the output file; a str is also accepted
    :param accessor: source of the image data, whose axes are Y, X, C, Z
    :param mkdir: if True, first create the parent directory and any missing ancestors
    :return: True on success
    :raises FileWriteError: if the data could not be converted or written
    """
    fpath = Path(fpath)
    if mkdir:
        fpath.parent.mkdir(parents=True, exist_ok=True)
    try:
        # reorder axes from Y, X, C, Z to Z, C, Y, X for the ImageJ hyperstack
        zcyx = np.moveaxis(
            accessor.data,  # yxcz
            [3, 2, 0, 1],
            [0, 1, 2, 3]
        )
        if accessor.is_mask():
            zcyx = zcyx.astype('uint8')
        tifffile.imwrite(fpath, zcyx, imagej=True)
    except Exception as e:
        raise FileWriteError(f'Unable to write data to file {fpath}') from e
    return True


def generate_file_accessor(fpath):
    """
    Instantiate the file accessor that matches the extension of fpath (case-insensitive)
    :param fpath: path to an image file ending in .tif, .tiff, .czi, or .png
    :return: instance of the matching GenericImageFileAccessor subclass
    :raises FileAccessorError: if no accessor handles the file's extension
    """
    name = str(fpath).upper()
    if name.endswith(('.TIF', '.TIFF')):
        accessor_class = TifSingleSeriesFileAccessor
    elif name.endswith('.CZI'):
        accessor_class = CziImageFileAccessor
    elif name.endswith('.PNG'):
        accessor_class = PngFileAccessor
    else:
        raise FileAccessorError(f'Could not match a file accessor with {fpath}')
    return accessor_class(fpath)

class Error(Exception):
    """Base class for the errors defined in this module."""


class FileAccessorError(Error):
    """Raised when an image file is missing, cannot be opened, or has no matching accessor."""


class FileNotFoundError(Error):
    """
    NOTE(review): this name shadows the builtin FileNotFoundError inside this module and for anyone doing
    `from accessors import *`.  No code in this file raises it; missing files are reported as FileAccessorError.
    """


class DataShapeError(Error):
    """Raised when image data cannot be arranged into Y, X, C, Z axes."""


class FileWriteError(Error):
    """Raised when accessor data cannot be written to a file."""


class InvalidAxisKey(Error):
    """Error for an unrecognized axis key; not raised by code in this file."""


class InvalidDataShape(Error):
    """Error for an unsupported data shape; not raised by code in this file."""