Skip to content
Snippets Groups Projects
accessors.py 15.55 KiB
from abc import ABC, abstractmethod
import os
from pathlib import Path

import numpy as np
from skimage.io import imread, imsave

import czifile
import tifffile

from .process import make_rgb
from .process import is_mask

class GenericImageDataAccessor(ABC):
    """
    Abstract interface for image data held as a 4D np.ndarray whose axes follow the order in ``axes``
    (Y, X, C, Z for this base class).  Concrete subclasses populate ``self._data``.
    """

    axes = 'YXCZ'

    @abstractmethod
    def __init__(self):
        """
        Abstract base class that exposes an interface for image data, irrespective of whether it is instantiated
        from file I/O or other means.  Enforces Y, X, C, Z dimensions in that order.
        """
        pass

    @property
    def chroma(self):
        """Number of channels, i.e. the size of the C axis"""
        return self.shape_dict['C']

    @staticmethod
    def conform_data(data):
        """
        Pad an array of at most four dimensions with trailing singleton axes so that it becomes 4D
        :param data: np.ndarray with no more than four dimensions and no zero-length dimension
        :return: reshaped np.ndarray with exactly four dimensions
        :raises DataShapeError: if data has more than four dimensions or an empty dimension
        """
        if len(data.shape) > 4 or (0 in data.shape):
            raise DataShapeError(f'Cannot handle image with dimensions other than X, Y, C, and Z: {data.shape}')
        # append trailing size-1 axes until the array is 4D
        return data.reshape(data.shape + (1,) * (4 - len(data.shape)))

    def is_3d(self):
        """True if the data has more than one z-level"""
        return self.nz > 1

    def is_mask(self):
        """Apply the module-level is_mask() check to the raw data array"""
        return is_mask(self._data)

    def get_channels(self, channels: list, mip: bool = False):
        """
        Return a new accessor containing only the requested channels
        :param channels: channel indices; each is converted with int()
        :param mip: if True, additionally take the maximum intensity projection along Z (Z is kept with size 1)
        :return: accessor built by _derived_accessor
        """
        carr = [int(c) for c in channels]
        nda = self.data.take(indices=carr, axis=self._ga('C'))
        if mip:
            nda = nda.max(axis=self._ga('Z'), keepdims=True)
        return self._derived_accessor(nda)

    def get_zi(self, zi: int):
        """
        Return a new accessor of a specific z-coordinate; the Z axis is kept with size 1
        """
        return self._derived_accessor(self.data.take(indices=[zi], axis=self._ga('Z')))

    def get_mip(self):
        """
        Return a new accessor of maximum intensity projection (MIP) along z-axis
        """
        return self.apply(lambda x: x.max(axis=self._ga('Z'), keepdims=True))

    def get_mono(self, channel: int, mip: bool = False):
        """Return a new accessor of a single channel; see get_channels"""
        return self.get_channels([channel], mip=mip)

    def get_z_argmax(self):
        """
        Return a new accessor whose values are the z-index of maximum intensity (Z axis reduced)
        """
        return self.apply(lambda x: x.argmax(axis=self._ga('Z')))

    def get_focus_vector(self):
        """
        Sum of the data over array axes 0, 1 and 2 (Y, X and C for YXCZ data): one value per z-level
        """
        return self.data.sum(axis=(0, 1, 2))

    @property
    def data_xy(self) -> np.ndarray:
        """
        2D (Y, X) view of the data
        :raises InvalidDataShape: unless the accessor has exactly one channel and one z-level
        """
        # parenthesized: the condition must reject anything that is not (single channel AND single z-level)
        if not (self.chroma == 1 and self.nz == 1):
            raise InvalidDataShape('Can only return XY array from accessors with a single channel and single z-level')
        return self.data[:, :, 0, 0]

    @property
    def data_xyz(self) -> np.ndarray:
        """
        3D (Y, X, Z) view of the data
        :raises InvalidDataShape: unless the accessor has exactly one channel
        """
        if self.chroma != 1:
            raise InvalidDataShape('Can only return XYZ array from accessors with a single channel')
        return self.data[:, :, 0, :]

    def _gc(self, channels):
        """Shorthand for get_channels accepting any iterable of channel indices"""
        return self.get_channels(list(channels))

    def unique(self):
        """Unique values in the data and their counts, as returned by np.unique(return_counts=True)"""
        return np.unique(self.data, return_counts=True)

    @property
    def pixel_scale_in_micrometers(self):
        """Per-axis pixel scale; empty here, subclasses with metadata may override"""
        return {}

    @property
    def dtype(self):
        """dtype of the underlying data array"""
        return self.data.dtype

    def write(self, fp: Path, mkdir=True):
        """Write this accessor to fp via write_accessor_data_to_file"""
        write_accessor_data_to_file(fp, self, mkdir=mkdir)

    def get_axis(self, ch):
        """
        Index of axis ch (case-insensitive) within ``axes``
        :raises ValueError: if ch is not one of ``axes``
        """
        return self.axes.index(ch.upper())

    def _ga(self, arg):
        """Shorthand for get_axis"""
        return self.get_axis(arg)

    def crop_hw(self, yxhw: tuple):
        """
        Return subset of data cropped in X and Y
        :param yxhw: tuple (Y, X, H, W)
        :return: InMemoryDataAccessor of size (H x W), starting at (Y, X)
        """
        y, x, h, w = yxhw
        return InMemoryDataAccessor(self.data[y: (y + h), x: (x + w), :, :])

    @property
    def hw(self):
        """
        Get data height and width as a tuple
        :return: tuple of (Y, X) dimensions
        """
        return self.shape_dict['Y'], self.shape_dict['X']

    @property
    def nz(self):
        """Number of z-levels, i.e. the size of the Z axis"""
        return self.shape_dict['Z']

    @property
    def data(self):
        """
        Return data as 4d with axes in order of Y, X, C, Z
        :return: np.ndarray
        """
        return self._data

    @property
    def shape(self):
        """Shape of the underlying data array"""
        return self._data.shape

    @property
    def shape_dict(self):
        """Mapping of axis name to size"""
        return dict(zip(('Y', 'X', 'C', 'Z'), self.data.shape))

    @staticmethod
    def _derived_accessor(data):
        """
        Create a new accessor given np.ndarray data; used for example in slicing operations
        """
        return InMemoryDataAccessor(data)

    def apply(self, func):
        """
        Apply func to data and return as a new in-memory accessor
        :param func: function that receives and returns the same size np.ndarray
        :return: InMemoryDataAccessor
        """
        return self._derived_accessor(func(self.data))

    @property
    def info(self):
        """Summary dict with shape_dict, dtype (as str) and an empty filepath"""
        return {
            'shape_dict': self.shape_dict,
            'dtype': str(self.dtype),
            'filepath': '',
        }

class InMemoryDataAccessor(GenericImageDataAccessor):
    """
    Accessor around an np.ndarray that is already in memory; the array is padded to 4D by conform_data.
    """
    def __init__(self, data):
        conformed = self.conform_data(data)
        self._data = conformed

class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file
    """
    Base class for accessors whose image data originates in a file, whose path is kept in self.fpath.
    """
    def __init__(self, fpath: Path):
        """
        Interface for image data that originates in an image file
        :param fpath: absolute path to image file
        :raises FileAccessorError: if no file exists at fpath
        """
        if not os.path.exists(fpath):
            raise FileAccessorError(f'Could not find file at {fpath}')
        self.fpath = fpath

    @staticmethod
    def read(fp: Path):
        """
        Create the accessor matching fp's file extension (delegates to generate_file_accessor)
        """
        return generate_file_accessor(fp)

    @property
    def info(self):
        """
        Base accessor info, with 'filepath' set to the string form of self.fpath
        """
        d = super().info
        d['filepath'] = str(self.fpath)
        return d

class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
    """
    Image stored in a TIF file containing exactly one series, exposed as Y, X, C, Z data.
    The tifffile.TiffFile handle is kept on self.tf and closed in __del__.
    """
    def __init__(self, fpath: Path):
        """
        :param fpath: path to a TIF file
        :raises FileAccessorError: if the file is missing or tifffile cannot open it
        :raises DataShapeError: if the file does not contain exactly one series, or the data has more than
            four dimensions after reordering
        """
        super().__init__(fpath)

        try:
            tf = tifffile.TiffFile(fpath)
        except Exception as e:
            raise FileAccessorError(f'Unable to access data in {fpath}') from e
        self.tf = tf

        if len(tf.series) != 1:
            raise DataShapeError(f'Expect only one series in {fpath}')

        se = tf.series[0]

        order = ['Y', 'X', 'C', 'Z']
        # NOTE(review): only Y/X/C/Z are tracked in axs, while da keeps every axis of the series, so this
        # relies on any other axes being trailing (a trailing extra axis ends up in the C position) — confirm
        # for other layouts
        axs = [a for a in se.axes if a in order]
        da = se.asarray()

        # append singleton C and/or Z axes when the file lacks them
        for ax in ('C', 'Z'):
            if ax not in axs:
                axs.append(ax)
                da = np.expand_dims(da, da.ndim)

        yxcz = np.moveaxis(
            da,
            [axs.index(k) for k in order],
            [0, 1, 2, 3]
        )

        self._data = self.conform_data(yxcz.reshape(yxcz.shape[0:4]))

    def __del__(self):
        # self.tf does not exist if __init__ raised before the file was opened (e.g. file not found)
        tf = getattr(self, 'tf', None)
        if tf is not None:
            tf.close()

class PngFileAccessor(GenericImageFileAccessor):
    """
    Image stored in a PNG file, read with skimage.io.imread and exposed as Y, X, C, Z data (single z-level).
    """
    def __init__(self, fpath: Path):
        """
        :param fpath: path to a PNG file
        :raises FileAccessorError: if the file is missing or cannot be read
        :raises DataShapeError: if the decoded image cannot be conformed to four dimensions
        """
        super().__init__(fpath)

        try:
            arr = imread(fpath)
        except Exception as e:
            # previously the exception was constructed but not raised, leaving arr unbound
            raise FileAccessorError(f'Unable to access data in {fpath}') from e

        if arr.ndim == 3:  # rgb: channels already on the last axis, add Z
            yxcz = np.expand_dims(arr, 3)
        else:  # mono: add C and Z
            yxcz = np.expand_dims(arr, (2, 3))
        # validate shape consistently with the other file accessors
        self._data = self.conform_data(yxcz)

class CziImageFileAccessor(GenericImageFileAccessor):
    """
    Image that is stored in a Zeiss .CZI file; may be multi-channel, and/or a z-stack,
    but not a time series or multiposition acquisition.  Only uncompressed files are accepted.
    The czifile.CziFile handle is kept on self.czifile and closed in __del__.
    """
    def __init__(self, fpath: Path):
        """
        :param fpath: path to a .CZI file
        :raises FileAccessorError: if the file is missing or czifile cannot open it
        :raises InvalidCziCompression: if OriginalCompressionMethod is missing or not UNCOMPRESSED
        :raises DataShapeError: if there is more than one S position or T point; other non-singleton axes make
            the final reshape fail
        """
        super().__init__(fpath)

        try:
            cf = czifile.CziFile(fpath)
        except Exception as e:
            raise FileAccessorError(f'Unable to access CZI data in {fpath}') from e
        self.czifile = cf

        try:
            md = cf.metadata(raw=False)
            compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod']
        except KeyError as e:
            raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod') from e
        if compmet.upper() != 'UNCOMPRESSED':
            raise InvalidCziCompression(f'Unsupported compression method {compmet}')

        sd = dict(zip(cf.axes, cf.shape))
        if sd.get('S', 1) > 1 or sd.get('T', 1) > 1:
            raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')

        axes = list(cf.axes)
        arr = cf.asarray()
        # append singleton C and/or Z axes if the file lacks them, instead of failing with a KeyError
        for ax in ('C', 'Z'):
            if ax not in axes:
                axes.append(ax)
                arr = np.expand_dims(arr, arr.ndim)

        yxcz = np.moveaxis(
            arr,
            [axes.index(ch) for ch in ('Y', 'X', 'C', 'Z')],
            [0, 1, 2, 3]
        )
        # all remaining (non-YXCZ) axes must be singletons for this reshape to succeed
        self._data = self.conform_data(yxcz.reshape(yxcz.shape[0:4]))

    def __del__(self):
        # self.czifile does not exist if __init__ raised before the file was opened (e.g. file not found)
        cf = getattr(self, 'czifile', None)
        if cf is not None:
            cf.close()

    @property
    def pixel_scale_in_micrometers(self):
        """
        Pixel scale per axis read from the CZI Scaling/Items/Distance metadata
        :return: dict of axis Id -> Value * 1e6, for entries whose Id is an axis of this accessor and whose
            DefaultUnitFormat is the micrometer string
        """
        scale_meta = self.czifile.metadata(raw=False)['ImageDocument']['Metadata']['Scaling']['Items']['Distance']
        sc = {}
        for m in scale_meta:
            # compare UTF-8 bytes: micro sign U+00B5 followed by 'm'
            if m['DefaultUnitFormat'].encode() == b'\xc2\xb5m' and m['Id'] in self.shape_dict:
                sc[m['Id']] = m['Value'] * 1e6  # presumably Value is in meters — confirm against CZI spec
        return sc


def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdir=True) -> bool:
    """
    Export an image accessor to file
    :param fpath: complete path including filename and extension (.png, .tif or .tiff, case-insensitive);
        a str path is also accepted
    :param acc: image accessor to be written; must not have a P (patch position) axis
    :param mkdir: create any needed subdirectories in fpath if True
    :return: True
    :raises FileWriteError: if acc has a P axis, the extension is unsupported, or PNG data is not uint8
    """
    fpath = Path(fpath)  # .suffix and .parent below require a Path
    if 'P' in acc.shape_dict:
        raise FileWriteError('Can only write single-position accessor to file')
    ext = fpath.suffix.upper()
    if ext not in ('.PNG', '.TIF', '.TIFF'):
        # reject before creating any directories
        raise FileWriteError(f'Unable to write data to file of extension {ext}')

    if mkdir:
        fpath.parent.mkdir(parents=True, exist_ok=True)

    if ext == '.PNG':
        if acc.dtype != 'uint8':
            raise FileWriteError(f'Invalid data type {acc.dtype}')
        # only z-index 0 is written to PNG
        if acc.chroma == 1:
            data = acc.data[:, :, 0, 0]
        elif acc.chroma == 2:  # add a blank blue channel
            data = make_rgb(acc.data)[:, :, :, 0]
        else:  # preserve RGB order
            data = acc.data[:, :, :, 0]
        imsave(fpath, data, check_contrast=False)
        return True

    # .TIF / .TIFF: reorder Y, X, C, Z -> Z, C, Y, X before writing with imagej=True
    zcyx = np.moveaxis(
        acc.data,  # yxcz
        [3, 2, 0, 1],
        [0, 1, 2, 3]
    )
    if acc.is_mask():
        if acc.dtype == 'bool':
            data = (zcyx * 255).astype('uint8')  # True -> 255
        else:
            data = zcyx.astype('uint8')
    else:
        data = zcyx
    tifffile.imwrite(fpath, data, imagej=True)
    return True


def generate_file_accessor(fpath):
    """
    Pick and construct the file accessor class for fpath based on its extension (case-insensitive):
    .tif/.tiff -> TifSingleSeriesFileAccessor, .czi -> CziImageFileAccessor, .png -> PngFileAccessor.
    The file is expected to hold a single position, single- or multi-channel, single plane or z-stack.
    :raises FileAccessorError: if the extension is not one of the above
    """
    upper_path = str(fpath).upper()
    if upper_path.endswith(('.TIF', '.TIFF')):
        return TifSingleSeriesFileAccessor(fpath)
    if upper_path.endswith('.CZI'):
        return CziImageFileAccessor(fpath)
    if upper_path.endswith('.PNG'):
        return PngFileAccessor(fpath)
    raise FileAccessorError(f'Could not match a file accessor with {fpath}')


class PatchStack(InMemoryDataAccessor):
    """
    Stack of patches held as a 5D np.ndarray with axes P, Y, X, C, Z.
    """

    axes = 'PYXCZ'

    def __init__(self, data, force_ydim_longest=False):
        """
        A sequence of n (generally) color 3D images of the same size
        :param data: either a list of np.ndarrays of size YXCZ, or np.ndarray of size PYXCZ
        :param force_ydim_longest: if creating a PatchStack from a list of different-sized patches, rotate each
            as needed so that height is always greater than or equal to width
        :raises InvalidDataForPatchStackError: if data is an empty list, a list containing non-4D arrays, or
            anything other than a list or a 5D np.ndarray
        """
        self._slices = []
        if isinstance(data, list):  # list of YXCZ patches
            if len(data) == 0:
                raise InvalidDataForPatchStackError('Cannot create accessor from an empty list of patches')
            if any(e.ndim != 4 for e in data):
                raise InvalidDataForPatchStackError('Expect each patch in list to be a 4D (YXCZ) array')
            shapes = np.array([e.shape for e in data])
            if force_ydim_longest:
                # longest side of any patch becomes the stack height; largest short side becomes its width
                psh = shapes[:, 0:2].max(axis=1).max()
                psw = shapes[:, 0:2].min(axis=1).max()
                psc, psz = shapes[:, 2:].max(axis=0)
                yxcz_shape = np.array([psh, psw, psc, psz])
            else:
                yxcz_shape = shapes.max(axis=0)
            # zero-filled canvas; each patch occupies its top-left corner
            nda = np.zeros((len(data), *yxcz_shape), dtype=data[0].dtype)
            for i, e in enumerate(data):
                h, w = e.shape[0:2]
                patch = np.rot90(e, axes=(0, 1)) if force_ydim_longest and w > h else e
                s = tuple(slice(0, c) for c in patch.shape)
                nda[i][s] = patch
                self._slices.append(s)

        elif isinstance(data, np.ndarray) and data.ndim == 5:  # interpret as PYXCZ
            nda = data
            # every patch spans the full YXCZ extent
            full = tuple(slice(0, c) for c in data.shape[1:])
            self._slices = [full] * data.shape[0]
        else:
            raise InvalidDataForPatchStackError(f'Cannot create accessor from {type(data)}')

        self._data = nda

    @staticmethod
    def _derived_accessor(data):
        """Wrap derived 5D data (e.g. from slicing operations) as a new PatchStack"""
        return PatchStack(data)

    def get_slice_at(self, i):
        """Tuple of slices covering patch i's own extent within the stack"""
        return self._slices[i]

    def iat(self, i, crop=False):
        """
        Return patch i as an InMemoryDataAccessor with axes Y, X, C, Z
        :param crop: if True, drop the zero-padding around the patch's own extent
        """
        patch = self.data[i, :, :, :, :]
        if crop:
            patch = patch[self._slices[i]]
        return InMemoryDataAccessor(patch)

    def iat_yxcz(self, i, crop=False):
        """Alias of iat"""
        return self.iat(i, crop=crop)

    @property
    def count(self):
        """Number of patches (size of the P axis)"""
        return self.shape_dict['P']

    def export_pyxcz(self, fpath: Path):
        """
        Write the stack to a TIF file with tifffile (imagej=True), axes reordered to P, Z, C, Y, X
        with P in the leading (T) position; boolean masks are written as uint8 0/255
        """
        tzcyx = np.moveaxis(
            self.pyxcz,  # pyxcz
            [0, 4, 3, 1, 2],
            [0, 1, 2, 3, 4]
        )

        if self.is_mask():
            if self.dtype == 'bool':
                data = (tzcyx * 255).astype('uint8')
            else:
                data = tzcyx.astype('uint8')
            tifffile.imwrite(fpath, data, imagej=True)
        else:
            tifffile.imwrite(fpath, tzcyx, imagej=True)

    def get_list(self):
        """
        One (P, Y, X) array per z-level, taken from channel 0
        """
        # index all five PYXCZ axes explicitly: channel 0, z-level zi
        return [self.data[:, :, :, 0, zi] for zi in range(self.nz)]

    @property
    def pyxcz(self):
        """Data with axes P, Y, X, C, Z (same as self.data)"""
        return self.data

    @property
    def pczyx(self):
        """Data with axes reordered to P, C, Z, Y, X"""
        return np.moveaxis(
            self.data,
            [0, 3, 4, 1, 2],
            [0, 1, 2, 3, 4]
        )

    @property
    def shape(self):
        """Shape of the 5D data array"""
        return self.data.shape

    @property
    def shape_dict(self):
        """Mapping of axis name (P, Y, X, C, Z) to size"""
        return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))


def make_patch_stack_from_file(fpath):  # interpret t-dimension as patch position
    """
    Load a single-series TIF file as a PatchStack, interpreting its T axis as patch position (P)
    :param fpath: path to a TIF file
    :return: PatchStack with axes P, Y, X, C, Z; missing T, Z or C axes become singletons
    :raises FileNotFoundError: (this module's class) if fpath does not exist
    :raises FileAccessorError: if tifffile cannot open fpath
    :raises DataShapeError: if the file does not contain exactly one series, or has a non-singleton axis
        other than T, Z, C, Y, X
    """
    if not Path(fpath).exists():
        raise FileNotFoundError(f'Could not find {fpath}')

    try:
        tf = tifffile.TiffFile(fpath)
    except Exception as e:
        raise FileAccessorError(f'Unable to access data in {fpath}') from e

    with tf:  # release the file handle even if validation fails
        if len(tf.series) != 1:
            raise DataShapeError(f'Expect only one series in {fpath}')
        se = tf.series[0]
        axes = list(se.axes)
        arr = se.asarray()

    order = 'TZCYX'
    # positions of axes we do not interpret; they are only tolerated with size 1
    extra = tuple(i for i, a in enumerate(axes) if a not in order)
    if any(arr.shape[i] != 1 for i in extra):
        raise DataShapeError(f'Cannot handle non-singleton axes other than {order} in {fpath}: {"".join(axes)}')
    if extra:
        arr = np.squeeze(arr, axis=extra)
        axes = [a for a in axes if a in order]

    for a in 'TZC':
        if a not in axes:
            axes.append(a)
            arr = np.expand_dims(arr, arr.ndim)

    # reorder by axis name (not by reshape) so any input axis order maps correctly to P, Y, X, C, Z
    pyxcz = np.moveaxis(
        arr,
        [axes.index(a) for a in 'TYXCZ'],
        [0, 1, 2, 3, 4],
    )
    return PatchStack(pyxcz)


class Error(Exception):
    """Base class of the exceptions raised by this module."""


class FileAccessorError(Error):
    """An image file could not be found, opened, or matched to an accessor class."""


class FileNotFoundError(Error):
    """A file passed to make_patch_stack_from_file does not exist.

    NOTE: this name shadows the builtin FileNotFoundError inside this module; it derives from Error, not
    OSError, so handlers catching the builtin will not catch it.
    """


class DataShapeError(Error):
    """Image data has dimensions or series/position/time structure that cannot be handled."""


class FileWriteError(Error):
    """An accessor could not be written to the requested file."""


class InvalidAxisKey(Error):
    """An axis key is not valid."""


class InvalidCziCompression(Error):
    """CZI compression metadata is missing or names an unsupported method."""


class InvalidDataShape(Error):
    """Accessor data does not have the shape required by the requested view."""


class InvalidDataForPatchStackError(Error):
    """Data passed to PatchStack cannot be interpreted as a stack of patches."""