Skip to content
Snippets Groups Projects
workflows.py 7.71 KiB
from pathlib import Path
from typing import Dict
from uuid import uuid4

import numpy as np
import pandas as pd

from extensions.ilastik.models import IlastikPixelClassifierModel
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack
from extensions.chaeo.zmask import build_zmask_from_object_mask, project_stack_from_focal_points
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.workflows import Timer

def _merge_patch_table(df: pd.DataFrame, files) -> pd.DataFrame:
    """Inner-join exported patch records onto the object table.

    Each record in ``files`` must carry a ``df_index`` key naming the row of
    ``df`` the patch was cut from; objects with no exported patch are dropped.
    The join is index-on-index, so the result keeps ``df``'s own index (a
    ``left_index=True, right_on='df_index'`` merge would instead adopt the
    positional index of the patch table, breaking any later merge against
    ``df_index``). Columns present on both sides get pandas' default
    ``_x``/``_y`` suffixes. A ``patch_id`` UUID column is added only if the
    table does not already have one, so each object keeps a single id across
    repeated merges.
    """
    df_patches = (
        pd.DataFrame(files)
        .set_index('df_index')
        .rename_axis(df.index.name)  # keep the object table's index name on the joined index
    )
    merged = pd.merge(df, df_patches, left_index=True, right_index=True)
    if 'patch_id' not in merged.columns:
        # list comprehension (rather than DataFrame.apply) also works on an empty table
        merged['patch_id'] = [uuid4() for _ in range(len(merged))]
    return merged


# TODO: unpack and validate inputs
# TODO: expose channel indices and color balance vectors to caller
def export_patches_from_multichannel_zstack(
        input_zstack_path: str,
        ilastik_project_file: str,
        pxmap_threshold: float,
        pixel_class: int,
        zmask_channel: int,
        patches_channel: int,
        where_output: str,
        mask_type: str = 'boxes',
        zmask_filters: Dict = None,
        zmask_expand_box_by: int = None,
        export_pixel_probabilities=True,
        export_2d_patches_for_training=True,
        export_2d_patches_for_annotation=True,
        export_3d_patches=True,
        export_annotated_zstack=True,
        export_patch_masks=True,
) -> Dict:
    """Segment objects in a multichannel z-stack and export patches derived from them.

    Steps:
      1. Load the stack at ``input_zstack_path``; it must have more than one z-slice.
      2. Max-project channel 0 along the data's last axis and classify the
         projection with the ilastik project ``ilastik_project_file``.
      3. Threshold the pixel-probability map at ``pxmap_threshold`` and build a
         z-mask from its ``pixel_class`` channel, using ``zmask_channel`` of the
         stack, ``mask_type``, ``zmask_filters`` and ``zmask_expand_box_by``.
      4. Depending on the ``export_*`` flags, write under ``where_output``: the
         pixel-probability map, 3D patches of ``patches_channel``, 2D patches for
         annotation and for training, patch masks, and a z-stack of
         ``patches_channel`` with boxes drawn on it.
      5. Project the stack from the centroids / ``zi`` of rows whose ``keeper``
         column is true and store the result in ``interm['projected']``.

    Returns:
        dict with keys ``pixel_model_id``, ``input_filepath``,
        ``number_of_objects`` (``len(zmask_meta)``), ``success`` (always True),
        ``timer_results`` (``Timer.events``), ``dataframe`` (the object table,
        inner-joined with the patch records of each 2D export that ran) and
        ``interm`` (intermediates from the z-mask step plus ``'projected'``).

    Raises:
        AssertionError: if the input is not a z-stack (``stack.nz <= 1``).
    """
    ti = Timer()
    stack = generate_file_accessor(Path(input_zstack_path))
    fstem = Path(input_zstack_path).stem
    ti.click('file_input')
    assert stack.nz > 1, 'Expecting z-stack'

    # MIP and classify pixels; NOTE(review): max is taken over the last data
    # axis, presumably z -- confirm against the accessor's axis order
    mip = InMemoryDataAccessor(
        stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True)
    )
    px_model = IlastikPixelClassifierModel(
        params={'project_file': Path(ilastik_project_file)}
    )
    pxmap, _ = px_model.infer(mip)
    ti.click('infer_pixel_probability')

    if export_pixel_probabilities:
        write_accessor_data_to_file(
            Path(where_output) / 'pixel_probabilities' / (fstem + '.tif'),
            pxmap
        )
        ti.click('export_pixel_probability')

    obmask = InMemoryDataAccessor(
        pxmap.data > pxmap_threshold
    )
    ti.click('threshold_pixel_mask')

    # make zmask; the mask array itself is not used further here
    _, zmask_meta, df, interm = build_zmask_from_object_mask(
        obmask.get_one_channel_data(pixel_class),
        stack.get_one_channel_data(zmask_channel),
        mask_type=mask_type,
        filters=zmask_filters,
        expand_box_by=zmask_expand_box_by,
    )
    ti.click('generate_zmasks')

    if export_3d_patches:
        export_patches_from_zstack(
            Path(where_output) / '3d_patches',
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
            prefix=fstem,
            draw_bounding_box=False,
            rescale_clip=0.0,
            make_3d=True,
        )
        ti.click('export_3d_patches')

    if export_2d_patches_for_annotation:
        files = export_multichannel_patches_from_zstack(
            Path(where_output) / '2d_patches_annotation',
            stack,
            zmask_meta,
            prefix=fstem,
            rescale_clip=0.001,
            make_3d=False,
            focus_metric='max_sobel',
            ch_white=4,
            ch_rgb_overlay=(3, None, None),
            draw_bounding_box=False,
            bounding_box_channel=1,
            bounding_box_linewidth=2,
            # draw_contour=True,
            draw_mask=True,
            overlay_gain=(0.1, 1.0, 1.0)
        )
        ti.click('export_2d_patches')
        # associate 2d patches, dropping labeled objects that were not exported as patches
        df = _merge_patch_table(df, files)

    if export_2d_patches_for_training:
        files = export_multichannel_patches_from_zstack(
            Path(where_output) / '2d_patches_training',
            stack.get_one_channel_data(4),  # NOTE: channel 4 is hard-coded (see TODO above)
            zmask_meta,
            prefix=fstem,
            rescale_clip=0.001,
            make_3d=False,
            focus_metric='max_sobel',
        )
        ti.click('export_2d_patches')
        # associate 2d patches; merging on the preserved object index keeps rows
        # aligned even when the annotation patches were merged in first
        df = _merge_patch_table(df, files)

    if export_patch_masks:
        export_patch_masks_from_zstack(
            Path(where_output) / 'patch_masks',
            stack.get_one_channel_data(4),  # NOTE: channel 4 is hard-coded (see TODO above)
            zmask_meta,
            prefix=fstem,
        )
        ti.click('export_patch_masks')

    if export_annotated_zstack:
        annotated = InMemoryDataAccessor(
            draw_boxes_on_3d_image(
                stack.get_one_channel_data(patches_channel).data,
                zmask_meta
            )
        )
        write_accessor_data_to_file(
            Path(where_output) / 'annotated_zstacks' / (fstem + '.tif'),
            annotated
        )
        ti.click('export_annotated_zstack')

    # generate multichannel projection from label centroids; NOTE(review): the
    # 'keeper', 'centroid-0', 'centroid-1' and 'zi' columns are assumed to come
    # from build_zmask_from_object_mask -- confirm
    dff = df[df['keeper']]
    interm['projected'] = project_stack_from_focal_points(
        dff['centroid-0'].to_numpy(),
        dff['centroid-1'].to_numpy(),
        dff['zi'].to_numpy(),
        stack,
        degree=4,
    )

    return {
        'pixel_model_id': px_model.model_id,
        'input_filepath': input_zstack_path,
        'number_of_objects': len(zmask_meta),
        'success': True,
        'timer_results': ti.events,
        'dataframe': df,
        'interm': interm,
    }

def transfer_ecotaxa_labels_to_patch_stacks(
        where_masks: str,
        object_csv: str,
        ecotaxa_tsv: str,
        where_output: str,
        patch_size: tuple = (256, 256),
) -> None:
    """Burn EcoTaxa annotation classes into a z-stack of binary patch masks.

    The object table ``object_csv`` (keyed by ``patch_id``) is joined with the
    EcoTaxa export ``ecotaxa_tsv`` (keyed by ``object_id``). Each distinct,
    non-null ``object_annotation_hierarchy`` gets an integer class id, starting
    at 1 so that no class collides with the background value 0 of the output.

    Writes into ``where_output`` (created if missing):
      * ``labels_key.csv``: class id, full hierarchy, lower-cased leaf class
        (text after the last ``>``) and the number of annotated patches per
        class (NaN for classes with no patch).
      * ``zstack_object_label.tif``: a uint8 stack with one slice per annotated
        patch; pixels equal to 255 in ``where_masks/<patch_filename>`` take the
        patch's class id, all others are 0.

    Args:
        where_masks: directory containing the patch mask files.
        object_csv: CSV with at least ``patch_id`` and ``patch_filename`` columns.
        ecotaxa_tsv: tab-separated EcoTaxa export with ``object_id`` and
            ``object_annotation_hierarchy`` columns.
        where_output: output directory.
        patch_size: expected ``hw`` of every mask.

    Raises:
        ValueError: if there are more classes than fit in a uint8 label.
        AssertionError: if a mask does not match ``patch_size`` or is not a
            single-channel, single-slice image.
    """
    df_obj = pd.read_csv(object_csv)
    df_ecotaxa = pd.read_csv(
        ecotaxa_tsv,
        sep='\t',
        header=[0],
        # NOTE(review): with header=[0] column names are plain strings, so these
        # tuple keys may not match any column -- confirm the intended header
        dtype={
            ('object_annotation_date', '[t]'): str,
            ('object_annotation_time', '[t]'): str,
            ('object_annotation_category_id', '[t]'): str,
        }
    )
    df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')

    # one class per distinct annotated hierarchy; nulls never become a class
    hierarchies = pd.Series(
        df_merge['object_annotation_hierarchy'].dropna().unique(),
        dtype=object,
    )
    n_classes = len(hierarchies)
    if n_classes > np.iinfo(np.uint8).max:
        raise ValueError(
            f'Too many annotation classes ({n_classes}) for a uint8 label stack'
        )
    df_labels = pd.DataFrame({
        # ids start at 1: 0 is the background value of the label stack
        'annotation_class_id': np.arange(1, n_classes + 1),
        'hierarchy': hierarchies,
        # leaf of the '>'-separated hierarchy (whole string when there is no '>')
        'annotation_class': hierarchies.str.rsplit(pat='>', n=1).str[-1].str.lower(),
    })

    df_pf = pd.merge(
        df_merge[['patch_filename', 'object_annotation_hierarchy']],
        df_labels,
        left_on='object_annotation_hierarchy',
        right_on='hierarchy',
    )
    df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]

    # per-class patch counts, looked up by class id rather than by row position
    df_labels['counts'] = df_labels['annotation_class_id'].map(
        df_pl['annotation_class_id'].value_counts()
    )
    out_dir = Path(where_output)
    out_dir.mkdir(parents=True, exist_ok=True)
    df_labels.to_csv(out_dir / 'labels_key.csv')

    # export patches as z-stack: one slice per annotated patch
    zstack = np.zeros((*patch_size, 1, len(df_pl)), dtype='uint8')
    for zi, row in enumerate(df_pl.itertuples(name='PatchFile')):
        acc_bm = generate_file_accessor(Path(where_masks) / row.patch_filename)
        assert acc_bm.hw == patch_size, \
            f'Unexpected patch size {acc_bm.hw}; expected {patch_size}'
        assert acc_bm.chroma == 1
        assert acc_bm.nz == 1
        zstack[:, :, 0, zi] = (acc_bm.data[:, :, 0, 0] == 255) * row.annotation_class_id

    # export masks as z-stack
    write_accessor_data_to_file(out_dir / 'zstack_object_label.tif', InMemoryDataAccessor(zstack))