Skip to content
Snippets Groups Projects
zmask.py 6.31 KiB
Newer Older
import numpy as np
import pandas as pd

from skimage.measure import find_contours, label, regionprops_table
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

from model_server.accessors import GenericImageDataAccessor
def build_zmask_from_object_mask(
        obmask: GenericImageDataAccessor,
        zstack: GenericImageDataAccessor,
        filters=None,
        mask_type='contours',
        expand_box_by=(0, 0),
):
    """
    Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of
    maximum intensity in z.  Return this zmask, a list of each kept object's meta information, the object table,
    and intermediate image products.

    :param obmask: GenericImageDataAccessor monochrome 2D binary mask of objects
    :param zstack: GenericImageDataAccessor monochrome zstack of same Y, X dimension as obmask
    :param filters: dictionary of form {attribute: (min, max)}; valid attributes are 'area' and 'solidity'.
        An object is kept only if min < attribute < max (strict) for every attribute given.
    :param mask_type: 'contours' (default): zmask is True on each kept object's pixels, at that object's z-index
        only; 'boxes': zmask is True in each kept object's complete (optionally expanded) bounding box
    :param expand_box_by: (xy, z) expands bounding box by (xy, z) pixels except where this hits a boundary
    :return: tuple (zmask, meta, df, interm)
        np.ndarray:
            boolean mask of same shape as zstack
        List containing one Dict per kept object, with keys:
            info: object's row of the object table (namedtuple), including bounding box (y0, y1, x0, x1) and
                z-index zi
            slice: named slice (np.s_) of (optionally) expanded bounding box
            relative_bounding_box: bounding box (y0, y1, x0, x1) in relative frame of (optionally) expanded
                bounding box
            contour: object's contour(s) returned by skimage.measure.find_contours, in full-frame coordinates
            mask: boolean mask of the object cropped to its own (unexpanded) bounding box; it sits at
                relative_bounding_box within the expanded frame
        pd.DataFrame: all labeled objects with bounding box, z-index (zi) and a boolean 'keeper' column that is
            True for objects passing filters
        Dict of intermediate image products:
            label_map: np.ndarray (h x w) where each unique object has an integer label
            argmax: np.ndarray (h x w) z-index of highest intensity in zstack
    """
    assert zstack.chroma == 1
    assert zstack.nz > 1
    assert mask_type in ('contours', 'boxes'), mask_type
    assert obmask.is_mask()
    assert obmask.chroma == 1
    assert obmask.nz == 1
    assert zstack.hw == obmask.hw
    h, w, _, nz = zstack.shape

    # assign object labels and build object query
    lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
    query_str = 'label > 0'  # always true
    if filters is not None:
        for k in filters.keys():
            assert k in ('area', 'solidity')
            vmin, vmax = filters[k]
            assert vmin >= 0
            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'

    # build dataframe of objects, assign z index to each object
    argmax = zstack.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
    df = (
        pd.DataFrame(
            regionprops_table(
                lamap,  # label_image is required: properties are measured per labeled object
                intensity_image=argmax,
                properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
            )
        )
        .rename(
            columns={
                'bbox-0': 'y0',
                'bbox-1': 'x0',
                'bbox-2': 'y1',
                'bbox-3': 'x1',
            }
        )
    )
    # an object's z-index is the rounded mean of argmax over its pixels
    df['zi'] = df['intensity_mean'].round().astype('int')
    df['keeper'] = False
    df.loc[df.query(query_str).index, 'keeper'] = True
    keepers = df[df['keeper']]

    # map each label to its kept object's z-index; background and filtered-out objects map to -1.
    # int() avoids uint16 wrap-around when computing the table size.
    lut = np.full(int(lamap.max()) + 1, -1.0)
    lut[keepers['label'].to_numpy()] = keepers['zi'].to_numpy()

    # per-object metadata, with bounding boxes converted to numpy slice objects
    ebxy, ebz = expand_box_by
    meta = []
    for ob in keepers.itertuples(name='LabeledObject'):
        # regionprops bbox maxima (y1, x1) are exclusive, so clamp them to h and w (not h - 1 and w - 1)
        y0 = max(ob.y0 - ebxy, 0)
        y1 = min(ob.y1 + ebxy, h)
        x0 = max(ob.x0 - ebxy, 0)
        x1 = min(ob.x1 + ebxy, w)
        # z range is inclusive of z1
        z0 = max(ob.zi - ebz, 0)
        z1 = min(ob.zi + ebz, nz - 1)

        # relative bounding box positions
        rbb = {
            'y0': ob.y0 - y0,
            'y1': ob.y1 - y0,
            'x0': ob.x0 - x0,
            'x1': ob.x1 - x0,
        }

        sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1]
        ob_mask = (lamap == ob.label)
        contour = find_contours(ob_mask)
        mask = ob_mask[ob.y0: ob.y1, ob.x0: ob.x1]

        meta.append({
            'info': ob,
            'slice': sl,
            'relative_bounding_box': rbb,
            'contour': contour,
            'mask': mask
        })

    # build mask z-stack
    zi_st = np.zeros(zstack.shape, dtype='bool')
    if mask_type == 'contours':
        # set each kept object's pixels True in its own z-plane only
        zi_map = lut[lamap]
        yy, xx = np.nonzero(zi_map >= 0)
        zi_st[yy, xx, :, zi_map[yy, xx].astype('int')] = True
    elif mask_type == 'boxes':
        for bb in meta:
            zi_st[bb['slice']] = True

    # return intermediate image arrays
    interm = {
        'label_map': lamap,
        'argmax': argmax,
    }

    return zi_st, meta, df, interm


def project_stack_from_focal_points(
        xx: np.ndarray,
        yy: np.ndarray,
        zz: np.ndarray,
        stack: GenericImageDataAccessor,
        degree: int = 2,
):
    # TODO: add weights
    """
    Fit a polynomial surface z = f(xx, yy) through a set of 3D focal points and use it to project a monochrome
    z-stack: each output pixel takes the stack's value at the z-index predicted by the surface at that pixel.

    :param xx: 1D array, first in-plane coordinate of each focal point; matched to the row (Y) index of the
        output grid. NOTE(review): confirm callers pass row coordinates here.
    :param yy: 1D array, second in-plane coordinate of each focal point (same shape as xx); matched to the
        column (X) index of the output grid
    :param zz: 1D array, z-index of each focal point (same shape as xx)
    :param stack: GenericImageDataAccessor monochrome z-stack to project
    :param degree: degree of the polynomial surface
    :return: np.ndarray (h x w) projected image, with the dtype of stack.data
    """
    assert xx.shape == yy.shape
    assert xx.shape == zz.shape
    assert stack.chroma == 1

    # fit z as a polynomial in (xx, yy); PolynomialFeatures' bias column replaces the regression intercept
    poly = PolynomialFeatures(degree=degree)
    features = poly.fit_transform(np.stack([xx, yy]).T)
    model = LinearRegression(fit_intercept=False)
    model.fit(features, zz)

    # evaluate the fitted surface at every pixel; np.indices yields (row, column) pairs
    output_shape = stack.hw
    grid = np.indices(output_shape).reshape(2, -1).T
    z_fit = poly.transform(grid) @ model.coef_

    # valid z-indices are 0 .. nz - 1; clipping to nz would index past the end of the stack
    zi_image = z_fit.reshape(output_shape).round().clip(0, stack.nz - 1).astype('uint16')

    return np.take_along_axis(
        stack.data[:, :, 0, :],
        np.expand_dims(zi_image, 2),
        axis=2
    ).reshape(output_shape)