Skip to content
Snippets Groups Projects
workflows.py 14.5 KiB
Newer Older
from skimage.measure import label, regionprops_table
from skimage.morphology import dilation
from sklearn.model_selection import train_test_split
from extensions.chaeo.accessors import MonoPatchStack
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.models import PatchStackObjectClassifier
from extensions.chaeo.process import mask_largest_object
from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.zmask import build_zmask_from_object_mask, project_stack_from_focal_points
from extensions.ilastik.models import IlastikPixelClassifierModel

from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
# Segment one channel of an input z-stack and build z-mask metadata for the objects found.
#
# Steps visible below:
#   1. Load the file.
#   2. Reduce the segmentation channel to one plane: the slice at zmask_zindex when an int
#      is given, otherwise a max-intensity projection over the last axis.
#   3. Optionally rescale that plane with zmask_clip.
#   4. Run the ilastik pixel classifier on it and threshold the probability map at
#      pxmap_threshold.
#   5. Pass the foreground channel of that mask, plus the full segmentation channel, to
#      build_zmask_from_object_mask.
#
# Returns (ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm):
#   ti     -- Timer with a click() recorded after each stage
#   stack  -- accessor for the whole input file
#   fstem  -- input filename without its suffix
#   obmask -- accessor wrapping the boolean array pxmap.data > pxmap_threshold
#   pxmap  -- probability-map accessor returned by the pixel classifier
#   zmask, zmask_meta, df, interm -- as returned by build_zmask_from_object_mask, with a
#            'pixel_scale_in_micrometers' column added to df
#
# NOTE(review): lines are missing from this view of the function: the signature's closing
# `):` and the closing `)` of the build_zmask_from_object_mask(...) call.
# export_patches_from_multichannel_zstack also passes zmask_expand_box_by=..., which the
# visible parameter list does not declare. Restore these from version control.
def get_zmask_meta(
    input_file_path: str,
    ilastik_pixel_classifier: IlastikPixelClassifierModel,
    segmentation_channel: int,
    pxmap_threshold: float,
    pxmap_foreground_channel: int = 0,
    zmask_zindex: int = None,
    zmask_clip: int = None,
    zmask_filters: Dict = None,
    zmask_type: str = 'boxes',
    # NOTE(review): the closing `):` of the signature belongs here but is missing from this view.
    ti = Timer()
    stack = generate_file_accessor(Path(input_file_path))
    fstem = Path(input_file_path).stem
    ti.click('file_input')

    # MIP if no zmask z-index is given, then classify pixels
    if isinstance(zmask_zindex, int):
        # NOTE(review): the strict lower bound rejects zmask_zindex == 0, i.e. the first slice;
        # confirm whether `0 <= zmask_zindex` was intended.
        assert 0 < zmask_zindex < stack.nz
        # NOTE(review): integer indexing drops the last axis, giving a 3-D array. The MIP branch
        # below keeps that axis (keepdims=True). Confirm InMemoryDataAccessor accepts both shapes,
        # or slice with zmask_zindex:zmask_zindex + 1.
        zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex]
    else:
        zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
    if zmask_clip:
        # rescale is not among the visible imports; presumably an intensity rescale clipped by
        # zmask_clip -- confirm. A zmask_clip of 0 or None skips this step.
        zmask_data = rescale(zmask_data, zmask_clip)
    mip = InMemoryDataAccessor(
        zmask_data,
    )
    pxmap, _ = ilastik_pixel_classifier.infer(mip)
    ti.click('infer_pixel_probability')

    # binary object mask: every pixel whose class probability exceeds the threshold
    obmask = InMemoryDataAccessor(
        pxmap.data > pxmap_threshold
    )
    ti.click('threshold_pixel_mask')

    # make zmask
    zmask, zmask_meta, df, interm = build_zmask_from_object_mask(
        obmask.get_one_channel_data(pxmap_foreground_channel),
        stack.get_one_channel_data(segmentation_channel),
        mask_type=zmask_type,
        filters=zmask_filters,
    # NOTE(review): the closing `)` of build_zmask_from_object_mask(...) is missing from this view.
    # record pixel scale
    # NOTE(review): float() raises TypeError if .get('X') returns None (no X scale recorded).
    df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X'))

    return ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm
# Run get_zmask_meta on one input file and export per-object artifacts under output_folder_path.
#
# Each export_* flag gates one output:
#   export_pixel_probabilities       -- pixel_probabilities/<stem>.tif
#   export_3d_patches                -- 3-D patches of patches_channel
#   export_2d_patches_for_annotation -- 2d_patches_annotation/; df is then inner-merged with the
#                                       exported patch table and given a uuid4 'patch_id' per row
#   export_2d_patches_for_training   -- 2d_patches_training/
#   export_patch_masks               -- patch_masks/
#   export_annotated_zstack          -- annotated_zstacks/<stem>.tif with object boxes drawn
# All patch exports are skipped when zmask_meta is empty. The pixel-probability and
# annotated-z-stack exports are not.
#
# Afterwards, interm['projected'] holds a projection through the centroids of rows flagged
# 'keeper', or a plain max projection when there are no objects. The trailing dict entries
# report pixel scale, a success flag, timer events and the 'keeper' rows of df.
#
# NOTE(review): this view is missing lines:
#   - the start and closing `):` of the parameter list. input_file_path, output_folder_path,
#     pixel_classifier, segmentation_channel, pxmap_threshold, pxmap_foreground_channel,
#     zmask_zindex, zmask_clip, zmask_type and patches_channel are used but not declared here;
#   - several closing `)`;
#   - the definition of df_patches;
#   - the `return {` that opens the result dict.
# rgb_overlay_channels and rgb_overlay_weights are not used in the visible body.
# TODO: unpack and validate inputs
def export_patches_from_multichannel_zstack(
        zmask_filters: Dict = None,
        zmask_expand_box_by: int = None,
        export_pixel_probabilities=True,
        export_2d_patches_for_training=True,
        export_2d_patches_for_annotation=True,
        draw_bounding_box_on_2d_patch=True,
        draw_contour_on_2d_patch=False,
        draw_mask_on_2d_patch=False,
        export_3d_patches=True,
        export_annotated_zstack=True,
        draw_label_on_zstack=False,
        export_patch_masks=True,
        rgb_overlay_channels=(None, None, None),
        rgb_overlay_weights=(1.0, 1.0, 1.0),
    # NOTE(review): the closing `):` of the signature is missing from this view.
    ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta(
        input_file_path,
        pixel_classifier,
        segmentation_channel,
        pxmap_threshold,
        pxmap_foreground_channel=pxmap_foreground_channel,
        zmask_zindex=zmask_zindex,
        zmask_clip=zmask_clip,
        zmask_expand_box_by=zmask_expand_box_by,
        zmask_filters=zmask_filters,
        zmask_type=zmask_type,
    # NOTE(review): the closing `)` of get_zmask_meta(...) is missing from this view.
    if export_pixel_probabilities:
        write_accessor_data_to_file(
            Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'),
            pxmap
        )
        ti.click('export_pixel_probability')
    if export_3d_patches and len(zmask_meta) > 0:
        # NOTE(review): unlike the other export calls, no output folder is passed here;
        # a line may be missing from this view -- check export_patches_from_zstack's signature.
        files = export_patches_from_zstack(
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
            prefix=fstem,
            draw_bounding_box=False,
            make_3d=True,
        )
        ti.click('export_3d_patches')
    if export_2d_patches_for_annotation and len(zmask_meta) > 0:
        # NOTE(review): unlike the training export below, no stack accessor is passed here;
        # a line may be missing from this view.
        files = export_multichannel_patches_from_zstack(
            Path(output_folder_path) / '2d_patches_annotation',
            zmask_meta,
            prefix=fstem,
            make_3d=False,
            focus_metric='max_sobel',
            draw_bounding_box=draw_bounding_box_on_2d_patch,
            bounding_box_channel=1,
            draw_contour=draw_contour_on_2d_patch,
            draw_mask=draw_mask_on_2d_patch,
        # NOTE(review): the closing `)` of this call is missing. df_patches, used below, is not
        # defined in this view (presumably built from `files`).
        # NOTE(review): the training export below records the same event name; if Timer keys
        # events by name, one timing overwrites the other.
        ti.click('export_2d_patches')
        # associate 2d patches, dropping labeled objects that were not exported as patches
        df = pd.merge(df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
        # prepopulate patch UUID
        df['patch_id'] = df.apply(lambda _: uuid4(), axis=1)
    if export_2d_patches_for_training and len(zmask_meta) > 0:
        files = export_multichannel_patches_from_zstack(
            Path(output_folder_path) / '2d_patches_training',
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
            prefix=fstem,
            rescale_clip=0.001,
            make_3d=False,
            focus_metric='max_sobel',
        )
        ti.click('export_2d_patches')
    if export_patch_masks and len(zmask_meta) > 0:
        files = export_patch_masks_from_zstack(
            Path(output_folder_path) / 'patch_masks',
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
    # NOTE(review): the closing `)` of export_patch_masks_from_zstack(...) is missing from this view.
    if export_annotated_zstack:
        annotated = InMemoryDataAccessor(
            draw_boxes_on_3d_image(
                stack.get_one_channel_data(patches_channel).data,
                zmask_meta,
                add_label=draw_label_on_zstack,
            )
        )
        write_accessor_data_to_file(
            Path(output_folder_path) / 'annotated_zstacks' / (fstem + '.tif'),
            annotated
        )
        ti.click('export_annotated_zstack')

    # generate multichannel projection from label centroids
    dff = df[df['keeper']]
    if len(zmask_meta) > 0:
        # focal plane interpolated through each kept object's (centroid-0, centroid-1, zi)
        interm['projected'] = project_stack_from_focal_points(
            dff['centroid-0'].to_numpy(),
            dff['centroid-1'].to_numpy(),
            dff['zi'].to_numpy(),
            stack,
            degree=4,
        )
    else: # else just return MIP
        interm['projected'] = stack.data.max(axis=-1)
    # NOTE(review): the `return {` that opens the result dict is missing from this view.
        # NOTE(review): this key has a capital L ('pixeL_'), unlike the 'pixel_scale_in_micrometers'
        # column set in get_zmask_meta. Renaming it changes the output contract; confirm with
        # consumers before fixing.
        'pixeL_scale_in_micrometers': stack.pixel_scale_in_micrometers,
        'success': True,
        'timer_results': ti.events,
        # same selection as `dff` above when 'keeper' is a boolean column
        'dataframe': df[df['keeper'] == True],
        # NOTE(review): the `def` line opening this parameter list is missing from this view.
        # Judging by their use below, so are a `**kwargs` parameter, the closing `):`, and the
        # assignment of pixel_classifier (presumably models[0]).
        #
        # Visible behavior:
        #   1. Check that models[1] is a PatchStackObjectClassifier.
        #   2. Run get_zmask_meta.
        #   3. Extract patch and patch-mask stacks.
        #   4. Classify them with the object classifier.
        #   5. Paint each object's predicted class into output_map, at that object's pixels
        #      in interm['label_map'].
        #   6. Return the timer events plus a per-object table.
        input_file_path: str,
        output_folder_path: str,
        models: List[Model],
        pxmap_threshold: float,
        pxmap_foreground_channel: int,
        segmentation_channel: int,
        patches_channel: int,
        zmask_zindex: int = None,  # None for MIP,
        zmask_clip: int = None,
        zmask_type: str = 'boxes',
        zmask_filters: Dict = None,
    # NOTE(review): pixel_classifier is not assigned anywhere in this view.
    assert isinstance(pixel_classifier, IlastikPixelClassifierModel)
    object_classifier = models[1]
    assert isinstance(object_classifier, PatchStackObjectClassifier)

    ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta(
        input_file_path,
        pixel_classifier,
        segmentation_channel,
        pxmap_threshold,
        pxmap_foreground_channel=pxmap_foreground_channel,
        zmask_zindex=zmask_zindex,
        zmask_clip=zmask_clip,
        zmask_filters=zmask_filters,
        zmask_type=zmask_type,
    )

    # extract patches to accessor
    # NOTE(review): unlike get_patch_masks_from_zmask_meta below, stack and zmask_meta are not
    # passed here; a line may be missing from this view.
    patches_acc = get_patches_from_zmask_meta(
        **kwargs
    )

    # extract masks
    patch_masks_acc = get_patch_masks_from_zmask_meta(
        stack,
        zmask_meta,
        **kwargs
    )

    # send patches and mask stacks to object classifier
    result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc)
    labels_map = interm['label_map']
    output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype)
    # labels_map *is* interm['label_map'], so these two asserts always pass
    assert labels_map.shape == interm['label_map'].shape
    assert labels_map.dtype == interm['label_map'].dtype
    # NOTE(review): `meta = []` and the loop header over objects (presumably
    # `for ii in range(len(zmask_meta)):`) are missing from this view.
        object_id = zmask_meta[ii]['info'].label
        result_patch = mask_largest_object(result_acc.iat(ii))
        # np.unique returns sorted values, so [1] is the second-smallest value (presumably the
        # class, with index 0 the background). Raises IndexError if the patch holds one value.
        object_class = np.unique(result_patch)[1]
        output_map[labels_map == object_id] = object_class
        # NOTE(review): keys look swapped. 'object_id' receives the loop index ii and
        # 'object_class' receives the label id object_id, while the computed object_class is
        # never recorded. Confirm intent before relying on this table.
        meta.append({'object_id': ii, 'object_class': object_id})
    # NOTE(review): output_path is unused in this view; output_map is neither written nor returned.
    output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))

    return {
        'timer_results': ti.events,
        'dataframe': pd.DataFrame(meta),
        'interm': {},
    # NOTE(review): the closing `}` of this dict is missing from this view.
def transfer_ecotaxa_labels_to_patch_stacks(
    where_masks: str,
    where_patches: str,
    object_csv: str,
    ecotaxa_tsv: str,
    where_output: str,
    patch_size: tuple = (256, 256),
    tr_split=0.6,
    dilate_label_mask: bool = True, # to mitigate connected components error in ilastik
    allow_multiple_objects: bool = False,
):
    """
    Transfer EcoTaxa annotations onto exported patches and write train/test z-stacks.

    Patch metadata in ``object_csv`` is joined to the EcoTaxa export ``ecotaxa_tsv`` on
    ``patch_id`` == ``object_id``. Each unique ``object_annotation_hierarchy`` gets a 1-based
    class index and a class name: the lowercased text after the last ``'>'``. Annotated
    patches are split randomly into training and test sets. ``where_output`` receives:

    * ``labels_key.csv`` -- class index, hierarchy, class name, and total / to_train / to_test counts
    * ``zstack_{train,test}_{mask,label,raw}.tif`` -- one patch per z-slice
    * ``{train,test}_stack.csv`` -- z-slice index, patch filename and class of every slice

    :param where_masks: directory holding a mask for each ``patch_filename``
    :param where_patches: directory holding the raw patch for each ``patch_filename``
    :param object_csv: patch table with at least ``patch_id`` and ``patch_filename`` columns
    :param ecotaxa_tsv: tab-separated EcoTaxa export with ``object_id`` and
        ``object_annotation_hierarchy`` columns
    :param where_output: directory to write all outputs to
    :param patch_size: expected (height, width) of every mask
    :param tr_split: fraction of annotated patches used for training; must exceed 0.5
    :param dilate_label_mask: dilate each mask before it is labeled and written
    :param allow_multiple_objects: if False, keep only the largest object of each mask
        (via ``mask_largest_object``)
    """
    assert tr_split > 0.5 # reduce chance that low-probability objects are omitted from training

    # read patch metadata and EcoTaxa annotations, then join them patch-by-patch
    df_obj = pd.read_csv(object_csv)
    df_ecotaxa = pd.read_csv(
        ecotaxa_tsv,
        sep='\t',  # the export is tab-separated; read_csv would otherwise split on commas
        dtype={
            # NOTE(review): tuple keys only name columns when the file is parsed with a two-row
            # header (column name + EcoTaxa type tag such as '[t]'). With a single header row,
            # these entries presumably match no column -- confirm the intended layout.
            ('object_annotation_date', '[t]'): str,
            ('object_annotation_time', '[t]'): str,
            ('object_annotation_category_id', '[t]'): str,
        }
    )
    df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')

    # assign each unique lowest-level annotation to a class index
    se_unique = pd.Series(df_merge.object_annotation_hierarchy.unique())
    # column 1 holds the leaf of each hierarchy, i.e. the text after the last '>'
    df_split = se_unique.str.rsplit(pat='>', n=1, expand=True)
    df_labels = pd.DataFrame({
        # 1-based so that no class collides with the 0 used for background in the label stacks
        'annotation_class_id': df_split.index + 1,
        'hierarchy': se_unique,
        'annotation_class': df_split.loc[:, 1].str.lower()
    })

    # join patch filenames and annotation classes, keeping only annotated patches
    df_pf = pd.merge(
        df_merge[['patch_filename', 'object_annotation_hierarchy']],
        df_labels,
        left_on='object_annotation_hierarchy',
        right_on='hierarchy',
    )
    df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]

    # export annotation classes and their summary stats
    df_tr, df_te = train_test_split(df_pl, train_size=tr_split)
    df_labels = pd.merge(
        df_labels,
        pd.DataFrame(
            [df_pl.annotation_class_id.value_counts(), df_tr.annotation_class_id.value_counts(), df_te.annotation_class_id.value_counts()],
            index=['total', 'to_train', 'to_test']
        ).T,
        left_on='annotation_class_id',
        right_index=True,
        how='outer'
    )
    # classes that are absent from a split have no value_counts entry; report them as 0
    for col in ('total', 'to_train', 'to_test'):
        df_labels.loc[df_labels[col].isna(), col] = 0
    df_labels.to_csv(Path(where_output) / 'labels_key.csv', index=False)

    # export patches as z-stacks: one z-slice per patch, in split order
    for (dfk, dfv) in {'train': df_tr, 'test': df_te}.items():
        zstack_keys = ['mask', 'label', 'raw']
        # NOTE(review): class ids above 255 would wrap silently in these uint8 stacks
        zstacks = {f'{dfk}_{zsk}': np.zeros((*patch_size, 1, len(dfv)), dtype='uint8') for zsk in zstack_keys}
        stack_meta = []
        for fi, pl in enumerate(dfv.itertuples(name='PatchFile')):
            fn = pl.patch_filename
            ac = pl.annotation_class
            aci = pl.annotation_class_id
            stack_meta.append({'zi': fi, 'patch_filename': fn, 'annotation_class': ac, 'annotation_class_id': aci})
            acc_bm = generate_file_accessor(Path(where_masks) / fn)
            assert acc_bm.is_mask()
            assert acc_bm.hw == patch_size, f'Unexpected patch size {acc_bm.hw}, expected {patch_size}'
            assert acc_bm.chroma == 1
            assert acc_bm.nz == 1
            mask = acc_bm.data[:, :, 0, 0]
            if dilate_label_mask:
                mask = dilation(mask)
            if not allow_multiple_objects:
                # label the (possibly dilated) mask; labeling acc_bm.data here would discard the dilation
                ob_id = label(mask)
                mask = mask_largest_object(ob_id)
            zstacks[dfk + '_mask'][:, :, 0, fi] = mask
            # NOTE(review): only pixels equal to 255 receive the class id -- confirm that
            # mask_largest_object (and the raw masks) use 255 for foreground
            zstacks[dfk + '_label'][:, :, 0, fi] = (mask == 255) * aci
            acc_pa = generate_file_accessor(Path(where_patches) / fn)
            # assumes single-channel patches: the raw stack has one channel slot
            zstacks[dfk + '_raw'][:, :, :, fi] = acc_pa.data[:, :, :, 0]
        for k, zstack in zstacks.items():
            write_accessor_data_to_file(Path(where_output) / f'zstack_{k}.tif', InMemoryDataAccessor(zstack))
        pd.DataFrame(stack_meta).to_csv(Path(where_output) / f'{dfk}_stack.csv', index=False)