from uuid import uuid4
import numpy as np
import pandas as pd
from skimage.measure import find_contours, label, regionprops_table
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.params import RoiFilter, RoiSetExportParams
from extensions.chaeo.process import mask_largest_object
from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import InstanceSegmentationModel


class RoiSet(object):

    def __init__(
            self,
            acc_mask: GenericImageDataAccessor,
            acc_raw: GenericImageDataAccessor,
            filters=None,
            mask_type='contours',
            expand_box_by=(0, 0),
    ):
        self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask(
            acc_mask,
            acc_raw,
            filters=filters,
            mask_type=mask_type,
            expand_box_by=expand_box_by,
        )
        self.acc_raw = acc_raw
        self.count = len(self.zmask_meta)
        self.object_id_labels = self.interm['label_map']

    def get_argmax(self):
        return self.interm['argmax']

    def export_3d_patches(self, where, prefix, channel, params: RoiSetExportParams):
        if not self.count:
            return
        files = export_patches_from_zstack(
            where,
            self.acc_raw.get_one_channel_data(channel),
            self.zmask_meta,
            prefix=prefix,
            draw_bounding_box=params.draw_bounding_box,
            rescale_clip=params.rescale_clip,
            make_3d=True,
        )

    def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams):
        if not self.count:
            return
        files = export_multichannel_patches_from_zstack(
            where,
            self.acc_raw.get_one_channel_data(channel),
            self.zmask_meta,
            prefix=prefix,
            draw_bounding_box=params.draw_bounding_box,
            rescale_clip=params.rescale_clip,
            make_3d=False,
            focus_metric=params.focus_metric,
            ch_white=channel,
            ch_rgb_overlay=params.rgb_overlay_channels,
            bounding_box_channel=1,
            bounding_box_linewidth=2,
            draw_contour=params.draw_contour,
            draw_mask=params.draw_mask,
            overlay_gain=params.rgb_overlay_weights,
        )
        df_patches = pd.DataFrame(files)
        self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
        self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)

    def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams):
        if not self.count:
            return
        files = export_multichannel_patches_from_zstack(
            where,
            self.acc_raw.get_one_channel_data(channel),
            self.zmask_meta,
            prefix=prefix,
            rescale_clip=params.rescale_clip,
            make_3d=False,
            focus_metric=params.focus_metric,
        )

    def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams):
        if not self.count:
            return
        files = export_patch_masks_from_zstack(
            where,
            self.acc_raw.get_one_channel_data(channel),
            self.zmask_meta,
            prefix=prefix,
        )

    def export_annotated_zstack(self, where, prefix, channel, params: RoiSetExportParams):
        annotated = InMemoryDataAccessor(
            draw_boxes_on_3d_image(
                self.acc_raw.get_one_channel_data(channel).data,
                self.zmask_meta,
                add_label=params.draw_label,
            )
        )
        write_accessor_data_to_file(
            where / 'annotated_zstacks' / (prefix + '.tif'),
            annotated
        )

    def get_multichannel_projection(self):
        dff = self.df[self.df['keeper']]
        if self.count:
            projected = project_stack_from_focal_points(
                dff['centroid-0'].to_numpy(),
                dff['centroid-1'].to_numpy(),
                dff['zi'].to_numpy(),
                self.acc_raw,
                degree=4,
            )
        else:  # if there are no objects, just return the maximum intensity projection
            projected = self.acc_raw.data.max(axis=-1)
        return projected

    def get_raw_patches(self, channel):
        return get_patches_from_zmask_meta(self.acc_raw, self.zmask_meta).get_one_channel_data(channel)

    def get_patch_masks(self):
        return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta)

    def classify_by(self, channel, object_classification_model: InstanceSegmentationModel):
        # classify on a patch basis, i.e. only one object per frame
        obmap_patches = object_classification_model.label_instance_class(
            self.get_raw_patches(channel),
            self.get_patch_masks()
        )
        lamap = self.object_id_labels
        output_map = np.zeros(lamap.shape, dtype=lamap.dtype)
        self.df['instance_class'] = np.nan

        # assign labels to object map:
        for ii in range(0, self.count):
            object_id = self.zmask_meta[ii]['info'].label
            result_patch = mask_largest_object(obmap_patches.iat(ii))
            object_class = np.unique(result_patch)[1]  # first non-zero value in the patch
            output_map[lamap == object_id] = object_class
            self.df.loc[self.df['label'] == object_id, 'instance_class'] = object_class
        return InMemoryDataAccessor(output_map)

    # TODO: test
    def get_object_mask_by_id(self, obj_id):
        return self.object_id_labels == obj_id

    def get_object_mask_by_class(self, class_id):
        return self.object_id_labels == class_id

    # TODO: implement
    def get_object_patch_by_id(self, obj_id):
        pass

    def get_object_map(self, filters: RoiFilter):
        pass
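

# Example usage (a minimal sketch; the accessor contents, channel index, and
# classifier below are illustrative assumptions, not part of this module):
#
#   mask_acc = InMemoryDataAccessor(binary_mask)   # (h, w, 1, 1) boolean array
#   stack_acc = InMemoryDataAccessor(zstack_data)  # (h, w, c, nz) array
#   rois = RoiSet(mask_acc, stack_acc, filters={'area': (50, 5000)})
#   obmap = rois.classify_by(0, classifier)        # classifier: InstanceSegmentationModel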


def build_zmask_from_object_mask(
        obmask: GenericImageDataAccessor,
        zstack: GenericImageDataAccessor,
        filters=None,
        mask_type='contours',
        expand_box_by=(0, 0),
):
    """
    Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of
    maximum intensity in z. Return this zmask and a list of each object's meta information.
    :param obmask: GenericImageDataAccessor monochrome 2D binary mask of objects
    :param zstack: GenericImageDataAccessor monochrome zstack of the same Y, X dimensions as obmask
    :param filters: dictionary of the form {attribute: (min, max)}; valid attributes are 'area' and 'solidity'
    :param mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours'
    :param expand_box_by: (xy, z) expands each bounding box by (xy, z) pixels except where this hits a boundary
    :return: tuple (zmask, meta, df, interm)
        np.ndarray: boolean mask of the same size as the zstack
        List containing one Dict per object, with keys:
            info: object's properties from skimage.measure.regionprops_table, including bounding box (y0, y1, x0, x1)
            slice: named slice (np.s_) of the (optionally) expanded bounding box
            relative_bounding_box: bounding box (y0, y1, x0, x1) in the frame of the (optionally) expanded bounding box
            contour: object's contour returned by skimage.measure.find_contours
            mask: mask of the object in the frame of the (optionally) expanded bounding box
        pd.DataFrame: objects, including bounding box information, after filtering
        Dict of intermediate image products:
            label_map: np.ndarray (h x w) where each unique object has an integer label
            argmax: np.ndarray (h x w) z-index of highest intensity in zstack
    """
    # validate inputs
    # assert zstack.chroma == 1
    assert mask_type in ('contours', 'boxes'), mask_type
    assert obmask.is_mask()
    assert obmask.chroma == 1
    assert obmask.nz == 1
    assert zstack.hw == obmask.hw

    # assign object labels and build object query
    lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
    query_str = 'label > 0'  # always true
    if filters is not None:
        for k in filters.keys():
            assert k in ('area', 'solidity')
            vmin, vmax = filters[k]
            assert vmin >= 0
            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
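    # e.g. filters={'area': (50, 5000)} (values are illustrative) produces
    # query_str == 'label > 0 & area > 50 & area < 5000'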

    # build dataframe of objects, assign z index to each object
    argmax = zstack.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
    df = (
        pd.DataFrame(
            regionprops_table(
                lamap,
                intensity_image=argmax,
                properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
            )
        )
        .rename(
            columns={
                'bbox-0': 'y0',
                'bbox-1': 'x0',
                'bbox-2': 'y1',
                'bbox-3': 'x1',
            }
        )
    )
    df['zi'] = df['intensity_mean'].round().astype('int')
    df['keeper'] = False
    df.loc[df.query(query_str).index, 'keeper'] = True

    # make an object map where label is replaced by focus position in stack and background is -1
    lut = np.zeros(lamap.max() + 1) - 1
    lut[df.label] = df.zi
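    # e.g. labels (1, 2) with zi (3, 5) give lut == [-1, 3, 5], so lut[lamap]
    # maps each labeled pixel to its focus z-index and background pixels to -1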

    # convert bounding boxes to numpy slice objects
    ebxy, ebz = expand_box_by
    h, w, c, nz = zstack.shape

    meta = []
    for ob in df[df['keeper']].itertuples(name='LabeledObject'):
        y0 = max(ob.y0 - ebxy, 0)
        y1 = min(ob.y1 + ebxy, h)
        x0 = max(ob.x0 - ebxy, 0)
        x1 = min(ob.x1 + ebxy, w)
        z0 = max(ob.zi - ebz, 0)
        z1 = min(ob.zi + ebz, nz)

        # relative bounding box positions
        rbb = {
            'y0': ob.y0 - y0,
            'y1': ob.y1 - y0,
            'x0': ob.x0 - x0,
            'x1': ob.x1 - x0,
        }

        sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1]

        # compute contours in the object's own frame; use a local name to avoid
        # shadowing the obmask parameter
        obj_mask = (lamap == ob.label)
        contour = find_contours(obj_mask)
        mask = obj_mask[ob.y0: ob.y1, ob.x0: ob.x1]

        assert rbb['x1'] <= (x1 - x0)
        assert rbb['y1'] <= (y1 - y0)

        meta.append({
            'df_index': ob.Index,
            'info': ob,
            'slice': sl,
            'relative_bounding_box': rbb,
            'contour': contour,
            'mask': mask,
        })

    # build mask z-stack
    zi_st = np.zeros(zstack.shape, dtype='bool')
    if mask_type == 'contours':
        zi_map = (lut[lamap] + 1.0).astype('int')
        idxs = np.array(zi_map) - 1
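        # one-hot encode each pixel's focus z-index along the z axis: every (y, x)
        # position becomes True at its object's z-plane; background (lut == -1)
        # wraps to the last frame and is cleared below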
        np.put_along_axis(
            zi_st,
            np.expand_dims(idxs, (2, 3)),
            1,
            axis=3
        )

        # change background level to 0 in final frame
        zi_st[:, :, :, -1][lamap == 0] = 0

    elif mask_type == 'boxes':
        for bb in meta:
            sl = bb['slice']
            zi_st[sl] = 1

    # return intermediate image arrays
    interm = {
        'label_map': lamap,
        'argmax': argmax,
    }

    return zi_st, meta, df, interm
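

# Example usage (a minimal sketch; the array names and filter values are
# illustrative assumptions):
#
#   zmask, meta, df, interm = build_zmask_from_object_mask(
#       InMemoryDataAccessor(binary_mask),   # (h, w, 1, 1)
#       InMemoryDataAccessor(zstack_data),   # (h, w, 1, nz)
#       filters={'area': (50, 5000)},
#       mask_type='boxes',
#       expand_box_by=(16, 1),
#   )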


def project_stack_from_focal_points(
        xx: np.ndarray,
        yy: np.ndarray,
        zz: np.ndarray,
        stack: GenericImageDataAccessor,
        degree: int = 2,
) -> np.ndarray:
    """
    Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points
    :param xx: vector of point x-coordinates
    :param yy: vector of point y-coordinates
    :param zz: vector of point z-coordinates
    :param stack: z-stack to project
    :param degree: order of polynomial to fit
    :return: multichannel 2D projected image array
    """
    assert xx.shape == yy.shape
    assert xx.shape == zz.shape

    # fit a polynomial surface z = f(x, y) through the focal points
    poly = PolynomialFeatures(degree=degree)
    X = np.stack([xx, yy]).T
    features = poly.fit_transform(X)
    model = LinearRegression(fit_intercept=False)
    model.fit(features, zz)

    # evaluate the fitted surface at every pixel position to get a per-pixel z-index
    xy_indices = np.indices(stack.hw).reshape(2, -1).T
    zi_image = np.dot(
        poly.transform(xy_indices),
        model.coef_
    ).reshape(
        stack.hw
    ).round().clip(
        0, (stack.nz - 1)
    ).astype('uint16')

    # sample each channel of the stack at the per-pixel z-index
    return np.take_along_axis(
        stack.data,
        np.repeat(
            np.expand_dims(zi_image, (2, 3)),
            stack.chroma,
            axis=2
        ),
        axis=3
    )
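

# Example usage (a minimal sketch; the focal-point dataframe and stack are
# illustrative assumptions):
#
#   proj = project_stack_from_focal_points(
#       df['centroid-0'].to_numpy(),   # row coordinates of in-focus objects
#       df['centroid-1'].to_numpy(),   # column coordinates
#       df['zi'].to_numpy(),           # per-object focus z-indices
#       InMemoryDataAccessor(zstack_data),
#       degree=4,
#   )  # returns an (h, w, c, 1) array sampled along the fitted focal surface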