from pathlib import Path
from typing import Dict, List
from uuid import uuid4
import numpy as np
import pandas as pd
from skimage.morphology import dilation
from sklearn.model_selection import train_test_split
from extensions.chaeo.accessors import MonoPatchStack
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.models import PatchStackObjectClassifier
from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.zmask import build_zmask_from_object_mask, project_stack_from_focal_points
from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import Model
from model_server.process import rescale
from model_server.workflows import Timer


def get_zmask_meta(
input_file_path: str,
ilastik_pixel_classifier: IlastikPixelClassifierModel,
segmentation_channel: int,
pxmap_threshold: float,
pxmap_foreground_channel: int = 0,
zmask_zindex: int = None,
zmask_clip: int = None,
zmask_expand_box_by: int = None,
zmask_filters: Dict = None,
zmask_type: str = 'boxes',
) -> tuple:
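    """
    Run pixel classification on one channel of a z-stack (a single z-plane if
    zmask_zindex is given, otherwise a maximum-intensity projection), threshold
    the probability map, and build a z-mask with per-object metadata, timing
    each step.

    Returns a tuple of (timer, input stack, filename stem, object mask,
    pixel probability map, zmask, zmask metadata, object dataframe,
    intermediate products).
    """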
ti = Timer()
stack = generate_file_accessor(Path(input_file_path))
fstem = Path(input_file_path).stem
ti.click('file_input')
# MIP if no zmask z-index is given, then classify pixels
if isinstance(zmask_zindex, int):
        # z-index 0 is a valid plane; slice (rather than index) to keep the z axis,
        # so this branch yields 4D data like the MIP branch below
        assert 0 <= zmask_zindex < stack.nz
        zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex:zmask_zindex + 1]
else:
zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
if zmask_clip:
zmask_data = rescale(zmask_data, zmask_clip)
mip = InMemoryDataAccessor(
zmask_data,
)
pxmap, _ = ilastik_pixel_classifier.infer(mip)
ti.click('infer_pixel_probability')
obmask = InMemoryDataAccessor(
pxmap.data > pxmap_threshold
)
ti.click('threshold_pixel_mask')
# make zmask
zmask, zmask_meta, df, interm = build_zmask_from_object_mask(
obmask.get_one_channel_data(pxmap_foreground_channel),
stack.get_one_channel_data(segmentation_channel),
mask_type=zmask_type,
filters=zmask_filters,
expand_box_by=zmask_expand_box_by,
)
ti.click('generate_zmasks')
return ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm


# TODO: unpack and validate inputs
def export_patches_from_multichannel_zstack(
input_file_path: str,
output_folder_path: str,
models: List[Model],
pxmap_threshold: float,
pxmap_foreground_channel: int,
segmentation_channel: int,
patches_channel: int,
zmask_zindex: int = None, # None for MIP,
zmask_clip: int = None,
zmask_type: str = 'boxes',
zmask_filters: Dict = None,
zmask_expand_box_by: int = None,
export_pixel_probabilities=True,
export_2d_patches_for_training=True,
export_2d_patches_for_annotation=True,
draw_bounding_box_on_2d_patch=True,
draw_contour_on_2d_patch=False,
draw_mask_on_2d_patch=False,
export_3d_patches=True,
export_annotated_zstack=True,
export_patch_masks=True,
rgb_overlay_channels=(None, None, None),
rgb_overlay_weights=(1.0, 1.0, 1.0),
) -> Dict:
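    """
    Build a z-mask for a multichannel z-stack using an ilastik pixel classifier,
    then export the products selected by the export_* flags: pixel probability
    maps, 2D and 3D object patches, patch masks, and annotated z-stacks.
    """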
pixel_classifier = models[0]
ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta(
input_file_path,
pixel_classifier,
segmentation_channel,
pxmap_threshold,
pxmap_foreground_channel=pxmap_foreground_channel,
zmask_zindex=zmask_zindex,
zmask_clip=zmask_clip,
zmask_expand_box_by=zmask_expand_box_by,
zmask_filters=zmask_filters,
zmask_type=zmask_type,
)
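    # export the products selected by the export_* flags, timing each stage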
if export_pixel_probabilities:
write_accessor_data_to_file(
Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'),
pxmap
)
ti.click('export_pixel_probability')
if export_3d_patches:
files = export_patches_from_zstack(
Path(output_folder_path) / '3d_patches',
stack.get_one_channel_data(patches_channel),
zmask_meta,
prefix=fstem,
draw_bounding_box=False,
rescale_clip=0.001,
make_3d=True,
)
ti.click('export_3d_patches')
if export_2d_patches_for_annotation:
files = export_multichannel_patches_from_zstack(
Path(output_folder_path) / '2d_patches_annotation',
stack,
zmask_meta,
prefix=fstem,
rescale_clip=0.001,
make_3d=False,
focus_metric='max_sobel',
ch_white=patches_channel,
ch_rgb_overlay=rgb_overlay_channels,
draw_bounding_box=draw_bounding_box_on_2d_patch,
bounding_box_channel=1,
bounding_box_linewidth=2,
draw_contour=draw_contour_on_2d_patch,
draw_mask=draw_mask_on_2d_patch,
overlay_gain=rgb_overlay_weights,
)
df_patches = pd.DataFrame(files)
ti.click('export_2d_patches')
# associate 2d patches, dropping labeled objects that were not exported as patches
df = pd.merge(df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
# prepopulate patch UUID
df['patch_id'] = df.apply(lambda _: uuid4(), axis=1)
if export_2d_patches_for_training:
files = export_multichannel_patches_from_zstack(
Path(output_folder_path) / '2d_patches_training',
stack.get_one_channel_data(patches_channel),
zmask_meta,
prefix=fstem,
rescale_clip=0.001,
make_3d=False,
focus_metric='max_sobel',
)
df_patches = pd.DataFrame(files)
ti.click('export_2d_patches')
if export_patch_masks:
files = export_patch_masks_from_zstack(
Path(output_folder_path) / 'patch_masks',
stack.get_one_channel_data(patches_channel),
zmask_meta,
prefix=fstem,
)
if export_annotated_zstack:
annotated = InMemoryDataAccessor(
draw_boxes_on_3d_image(
stack.get_one_channel_data(patches_channel).data,
zmask_meta
)
)
write_accessor_data_to_file(
Path(output_folder_path) / 'annotated_zstacks' / (fstem + '.tif'),
annotated
)
ti.click('export_annotated_zstack')
# generate multichannel projection from label centroids
dff = df[df['keeper']]
interm['projected'] = project_stack_from_focal_points(
dff['centroid-0'].to_numpy(),
dff['centroid-1'].to_numpy(),
dff['zi'].to_numpy(),
stack,
degree=4,
)
return {
'pixel_model_id': pixel_classifier.model_id,
'input_filepath': input_file_path,
'number_of_objects': len(zmask_meta),
        'pixel_scale_in_micrometers': stack.pixel_scale_in_micrometers,
'success': True,
'timer_results': ti.events,
'dataframe': df,
'interm': interm,
}


def get_object_map_from_zstack(
input_file_path: str,
output_folder_path: str,
models: List[Model],
pxmap_threshold: float,
pxmap_foreground_channel: int,
segmentation_channel: int,
patches_channel: int,
zmask_zindex: int = None, # None for MIP,
zmask_clip: int = None,
zmask_type: str = 'boxes',
zmask_filters: Dict = None,
zmask_expand_box_by: int = None,
**kwargs,
) -> Dict:
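    """
    Build a z-mask with an ilastik pixel classifier (models[0]), classify each
    extracted patch with a patch-stack object classifier (models[1]), and
    transfer the inferred class of each patch onto the object label map.
    """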
assert len(models) == 2
pixel_classifier = models[0]
assert isinstance(pixel_classifier, IlastikPixelClassifierModel)
object_classifier = models[1]
assert isinstance(object_classifier, PatchStackObjectClassifier)
ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta(
input_file_path,
pixel_classifier,
segmentation_channel,
pxmap_threshold,
pxmap_foreground_channel=pxmap_foreground_channel,
zmask_zindex=zmask_zindex,
zmask_clip=zmask_clip,
zmask_expand_box_by=zmask_expand_box_by,
zmask_filters=zmask_filters,
zmask_type=zmask_type,
)
# extract patches to accessor
patches_acc = get_patches_from_zmask_meta(
stack,
zmask_meta,
rescale_clip=zmask_clip,
make_3d=False,
**kwargs
)
# extract masks
patch_masks_acc = get_patch_masks_from_zmask_meta(
stack,
zmask_meta,
**kwargs
)
    # send patches and mask stacks to the object classifier; as with the pixel
    # classifier above, infer() is assumed to return (result accessor, metadata)
    result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc)
object_labels_map = np.copy(interm['label_map'])
assert object_labels_map.shape == interm['label_map'].shape
assert object_labels_map.dtype == interm['label_map'].dtype
    # assign the inferred class of each patch to the corresponding object in the label map
    for ii, mi in enumerate(zmask_meta):
        object_label_id = mi['info'].label
        result_label_map = result_acc.iat(ii)
        unique_values = np.unique(result_label_map)
        # each result patch should contain only background (0) and a single class label
        assert len(unique_values) == 2
        assert unique_values[0] == 0
        inferred_class = unique_values[1]
        object_labels_map[object_labels_map == object_label_id] = inferred_class
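    # NOTE: the original function ends here without a return despite the Dict
    # annotation; returning the relabeled object map and timing data is an
    # assumed completion, not part of the original code.
    ti.click('assign_object_classes')
    return {
        'object_labels_map': object_labels_map,
        'number_of_objects': len(zmask_meta),
        'timer_results': ti.events,
        'success': True,
    }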


def transfer_ecotaxa_labels_to_patch_stacks(
where_masks: str,
where_patches: str,
object_csv: str,
ecotaxa_tsv: str,
where_output: str,
patch_size: tuple = (256, 256),
tr_split=0.6,
dilate_label_mask: bool = True, # to mitigate connected components error in ilastik
) -> Dict:
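    """
    Merge exported patch metadata with EcoTaxa annotations, assign a class
    index to each unique annotation hierarchy, split the labeled patches into
    training and test sets, and write them out as mask, label, and raw
    z-stacks along with a labels_key.csv summary.
    """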
assert tr_split > 0.5 # reduce chance that low-probability objects are omitted from training
# read patch metadata
df_obj = pd.read_csv(
object_csv,
)
df_ecotaxa = pd.read_csv(
ecotaxa_tsv,
sep='\t',
header=[0],
        # with header=[0] the columns form a flat index, so plain column names are
        # used as dtype keys (the '[t]' codes belong to EcoTaxa's second header row)
        dtype={
            'object_annotation_date': str,
            'object_annotation_time': str,
            'object_annotation_category_id': str,
        }
)
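    # join exported patch metadata with EcoTaxa annotations, matching patch_id against EcoTaxa's object_id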
df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')
# assign each unique lowest-level annotation to a class index
se_unique = pd.Series(
df_merge.object_annotation_hierarchy.unique()
)
df_split = (
se_unique.str.rsplit(
pat='>', n=1, expand=True
)
)
df_labels = pd.DataFrame({
'annotation_class_id': df_split.index + 1,
'hierarchy': se_unique,
'annotation_class': df_split.loc[:, 1].str.lower()
})
# join patch filenames and annotation classes
df_pf = pd.merge(
df_merge[['patch_filename', 'object_annotation_hierarchy']],
df_labels,
left_on='object_annotation_hierarchy',
right_on='hierarchy',
)
df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]
# export annotation classes and their summary stats
df_tr, df_te = train_test_split(df_pl, train_size=tr_split)
# df_labels['counts'] = df_pl['annotation_class_id'].value_counts()
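    # attach per-class counts for the full set and for each split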
df_labels = pd.merge(
df_labels,
pd.DataFrame(
[df_pl.annotation_class_id.value_counts(), df_tr.annotation_class_id.value_counts(), df_te.annotation_class_id.value_counts()],
index=['total', 'to_train', 'to_test']
).T,
left_on='annotation_class_id',
right_index=True,
how='outer'
)
    # classes absent from a split get a count of zero
    for col in ['total', 'to_train', 'to_test']:
        df_labels.loc[df_labels[col].isna(), col] = 0
df_labels.to_csv(Path(where_output) / 'labels_key.csv', index=False)
# export patches as z-stacks
for (dfk, dfv) in {'train': df_tr, 'test': df_te}.items():
zstack_keys = ['mask', 'label', 'raw']
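        # allocate one (height, width, channel, z) array per output; each patch occupies one z-slice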
zstacks = {f'{dfk}_{zsk}': np.zeros((*patch_size, 1, len(dfv)), dtype='uint8') for zsk in zstack_keys}
stack_meta = []
for fi, pl in enumerate(dfv.itertuples(name='PatchFile')):
fn = pl._asdict()['patch_filename']
ac = pl._asdict()['annotation_class']
aci = pl._asdict()['annotation_class_id']
stack_meta.append({'zi': fi, 'patch_filename': fn, 'annotation_class': ac, 'annotation_class_id': aci})
acc_bm = generate_file_accessor(Path(where_masks) / fn)
assert acc_bm.is_mask()
            assert acc_bm.hw == patch_size, f'Unexpected patch size {acc_bm.hw}, expected {patch_size}'
assert acc_bm.chroma == 1
assert acc_bm.nz == 1
mask = acc_bm.data[:, :, 0, 0]
if dilate_label_mask:
mask = dilation(mask)
zstacks[dfk + '_mask'][:, :, 0, fi] = mask
zstacks[dfk + '_label'][:, :, 0, fi] = (mask == 255) * aci
acc_pa = generate_file_accessor(Path(where_patches) / fn)
zstacks[dfk + '_raw'][:, :, :, fi] = acc_pa.data[:, :, :, 0]
for k in zstacks.keys():
write_accessor_data_to_file(Path(where_output) / f'zstack_{k}.tif', InMemoryDataAccessor(zstacks[k]))
pd.DataFrame(stack_meta).to_csv(Path(where_output) / f'{dfk}_stack.csv', index=False)
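    # NOTE: the original function ends here without a return despite the Dict
    # annotation; returning the label key and output location is an assumed
    # completion, not part of the original code.
    return {
        'dataframe': df_labels,
        'output_folder': where_output,
        'success': True,
    }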