Skip to content
Snippets Groups Projects
Commit 91333047 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Progress on assembling RGB exports

parent cf20671e
No related branches found
No related tags found
No related merge requests found
......@@ -31,6 +31,17 @@ def _focus_metrics():
'moment': lambda x: moment(x.flatten(), moment=2),
}
def _safe_add(a, g, b):
assert a.dtype == b.dtype
assert a.shape == b.shape
assert g >= 0.0
return np.clip(
a.astype('uint32') + g * b.astype('uint32'),
0,
np.iinfo(a.dtype).max
).astype(a.dtype)
def write_patch_to_file(where, fname, yxcz):
ext = fname.split('.')[-1].upper()
where.mkdir(parents=True, exist_ok=True)
......@@ -71,7 +82,7 @@ def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
return MonoPatchStack(patches)
def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') -> list:
def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
patches_acc = get_patch_masks(roiset, pad_to=pad_to)
exported = []
......@@ -90,11 +101,44 @@ def get_patches_from_zmask_meta(
pad_to: int = 256,
make_3d: bool = False,
focus_metric: str = None,
rgb_overlay_channels: list = None,
rgb_overlay_weights: list = [1.0, 1.0, 1.0],
white_channel: int = None,
**kwargs
) -> pd.DataFrame:
# arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
raw = roiset.acc_raw
if isinstance(rgb_overlay_channels, list) and isinstance(rgb_overlay_weights, list):
assert all([c < raw.chroma for c in rgb_overlay_channels])
assert len(rgb_overlay_channels) == 3
assert len(rgb_overlay_weights) == 3
if white_channel:
assert white_channel < raw.chroma
stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
else:
stack = np.zeros([*raw.shape[0:1], 3, raw.shape[3]], dtype=raw.dtype)
for ii, ci in enumerate(rgb_overlay_channels):
if ci is None:
continue
assert isinstance(ci, int)
assert ci < raw.chroma
stack[:, :, ii, :] = _safe_add(
stack[:, :, ii, :], # either black or grayscale channel
rgb_overlay_weights[ii],
stack[:, :, ci, :]
)
else:
if white_channel:
assert white_channel < raw.chroma
stack = raw.data[:, :, [white_channel], :]
else:
stack = raw.data
def _make_patch(roi):
patch3d = roiset.acc_raw.data[roi.slice]
patch3d = stack[roi.slice]
ph, pw, pc, pz = patch3d.shape
subpatch = patch3d[roi.relative_slice]
......@@ -102,7 +146,7 @@ def get_patches_from_zmask_meta(
if make_3d:
patch = patch3d
# make a 2d patch, find optimal z-position determined by focus_metric function
# make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately
elif focus_metric is not None:
foc = _focus_metrics()[focus_metric]
......@@ -119,7 +163,6 @@ def get_patches_from_zmask_meta(
patch = patch3d[:, :, :, [zim]]
assert len(patch.shape) == 4
assert patch.shape[2] == roiset.acc_raw.chroma
if rescale_clip is not None:
patch = rescale(patch, rescale_clip)
......@@ -206,48 +249,52 @@ def export_multichannel_patches_from_zstack(
roiset,
rgb_overlay_channels: list = None,
rgb_overlay_weights: list = [1.0, 1.0, 1.0],
ch_white: int = None,
rgb_white_channel: int = None,
**kwargs
):
"""
Export RGB patches where each patch is assignable to a channel of the input stack
:param ch_rgb_overlay: tuple of integers (R, G, B) that assign a stack channel index to an RGB channel
:param overlay_gain: optional, tuple of float (R, G, B) multipliers that can be used to balance relative brightness
:param ch_white: int, index of stack channel that becomes grayscale signal in export patches
:param rgb_white_channel: int, index of stack channel that becomes grayscale signal in export patches
"""
def _safe_add(a, g, b):
assert a.dtype == b.dtype
assert a.shape == b.shape
assert g >= 0.0
return np.clip(
a.astype('uint32') + g * b.astype('uint32'),
0,
np.iinfo(a.dtype).max
).astype(a.dtype)
idata = roiset.acc_raw.data
if ch_white:
assert ch_white < roiset.acc_raw.chroma
mdata = idata[:, :, [ch_white, ch_white, ch_white], :]
else:
mdata = idata
if rgb_overlay_channels:
assert len(rgb_overlay_channels) == 3
assert len(rgb_overlay_weights) == 3
for ii, ci in enumerate(rgb_overlay_channels):
if ci is None:
continue
assert isinstance(ci, int)
assert ci < roiset.acc_raw.chroma
mdata[:, :, ii, :] = _safe_add(
mdata[:, :, ii, :],
rgb_overlay_weights[ii],
idata[:, :, ci, :]
)
# TODO: this is a bit of a workaround
mstack = InMemoryDataAccessor(mdata)
rgb_roiset = RoiSet(roiset.acc_obj_ids, mstack, roiset.params)
return export_patches_from_zstack(where, rgb_roiset, **kwargs)
\ No newline at end of file
# def _safe_add(a, g, b):
# assert a.dtype == b.dtype
# assert a.shape == b.shape
# assert g >= 0.0
#
# return np.clip(
# a.astype('uint32') + g * b.astype('uint32'),
# 0,
# np.iinfo(a.dtype).max
# ).astype(a.dtype)
#
# idata = roiset.acc_raw.data
# if rgb_white_channel:
# assert rgb_white_channel < roiset.acc_raw.chroma
# mdata = idata[:, :, [rgb_white_channel, rgb_white_channel, rgb_white_channel], :]
# else:
# mdata = idata
#
# if rgb_overlay_channels:
# assert len(rgb_overlay_channels) == 3
# assert len(rgb_overlay_weights) == 3
# for ii, ci in enumerate(rgb_overlay_channels):
# if ci is None:
# continue
# assert isinstance(ci, int)
# assert ci < roiset.acc_raw.chroma
# mdata[:, :, ii, :] = _safe_add(
# mdata[:, :, ii, :],
# rgb_overlay_weights[ii],
# idata[:, :, ci, :]
# )
#
# # TODO: this is a bit of a workaround
# mstack = InMemoryDataAccessor(mdata)
# rgb_roiset = RoiSet(roiset.acc_obj_ids, mstack, roiset.params)
# return export_patches_from_zstack(where, rgb_roiset, **kwargs)
kwargs['white_channel'] = rgb_white_channel
kwargs['rgb_overlay_channels'] = rgb_overlay_channels
kwargs['rgb_overlay_weights'] = rgb_overlay_weights
return export_patches_from_zstack(where, roiset, **kwargs)
\ No newline at end of file
......@@ -182,7 +182,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
# InMemoryDataAccessor(self.stack.data),
# roiset.zmask_meta,
roiset,
ch_white=4,
rgb_white_channel=4,
draw_bounding_box=True,
bounding_box_channel=1,
)
......@@ -198,7 +198,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
# InMemoryDataAccessor(self.stack.data),
# roiset.zmask_meta,
roiset,
ch_white=4,
rgb_white_channel=4,
ch_rgb_overlay=(3, None, None),
draw_mask=True,
mask_channel=0,
......@@ -215,7 +215,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
output_path / '2d_patches_chlorophyl_contour_overlay',
# InMemoryDataAccessor(self.stack.data),
roiset,
ch_white=4,
rgb_white_channel=4,
ch_rgb_overlay=(3, None, None),
draw_contour=True,
contour_channel=1,
......
......@@ -269,12 +269,12 @@ class RoiSet(object):
)
if k == 'patches_2d_for_annotation':
files = export_multichannel_patches_from_zstack(
subdir, self.acc_raw, self.zmask_meta, prefix=pr, make_3d=False, ch_white=channel,
subdir, self.acc_raw, self.zmask_meta, prefix=pr, make_3d=False, rgb_white_channel=channel,
bounding_box_channel=1, bounding_box_linewidth=2, **kp,
)
if k == 'patches_2d_for_training':
files = export_multichannel_patches_from_zstack(
subdir, self.acc_raw, self.zmask_meta, ch_white=channel, prefix=pr, make_3d=False, **kp
subdir, self.acc_raw, self.zmask_meta, rgb_white_channel=channel, prefix=pr, make_3d=False, **kp
)
df_patches = pd.DataFrame(files)
self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment