from pathlib import Path
from typing import Dict, List
from uuid import uuid4

import numpy as np
import pandas as pd

from skimage.measure import label, regionprops_table
from skimage.morphology import dilation
from sklearn.model_selection import train_test_split

from extensions.chaeo.accessors import MonoPatchStack
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.models import PatchStackObjectClassifier
from extensions.chaeo.params import RoiSetExportParams
from extensions.chaeo.process import mask_largest_object
from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet
from extensions.ilastik.models import IlastikPixelClassifierModel

from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
from model_server.process import rescale
from model_server.workflows import Timer

# def get_zmask_meta(
#     input_file_path: str,
#     ilastik_pixel_classifier: IlastikPixelClassifierModel,
#     segmentation_channel: int,
#     pxmap_threshold: float,
#     pxmap_foreground_channel: int = 0,
#     zmask_zindex: int = None,
#     zmask_clip: int = None,
#     zmask_filters: Dict = None,
#     zmask_type: str = 'boxes',
#     **kwargs,
# ) -> tuple:
#     ti = Timer()
#     stack = generate_file_accessor(Path(input_file_path))
#     fstem = Path(input_file_path).stem
#     ti.click('file_input')
#
#     # MIP if no zmask z-index is given, then classify pixels
#     if isinstance(zmask_zindex, int):
#         assert 0 < zmask_zindex < stack.nz
#         zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex]
#     else:
#         zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
#     if zmask_clip:
#         zmask_data = rescale(zmask_data, zmask_clip)
#     mip = InMemoryDataAccessor(
#         zmask_data,
#     )
#     pxmap, _ = ilastik_pixel_classifier.infer(mip)
#     ti.click('infer_pixel_probability')
#
#     obmask = InMemoryDataAccessor(
#         pxmap.data > pxmap_threshold
#     )
#     ti.click('threshold_pixel_mask')
#
#     # make zmask
#     obj_table = ZMaskObjectTable(
#         obmask.get_one_channel_data(pxmap_foreground_channel),
#         stack.get_one_channel_data(segmentation_channel),
#         mask_type=zmask_type,
#         filters=zmask_filters,
#         expand_box_by=kwargs['zmask_expand_box_by'],
#     )
#     ti.click('generate_zmasks')
#
#     # record pixel scale
#     obj_table.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X'))
#
#     return ti, stack, fstem, obmask, pxmap, obj_table


# # called by batch runners
# def export_patches_from_multichannel_zstack(
#         input_file_path: str,
#         output_folder_path: str,
#         models: List[Model],
#         pxmap_threshold: float,
#         pxmap_foreground_channel: int,
#         segmentation_channel: int,
#         patches_channel: int,
#         zmask_zindex: int = None,  # None for MIP,
#         zmask_clip: int = None,
#         zmask_type: str = 'boxes',
#         zmask_filters: Dict = None,
#         zmask_expand_box_by: int = None,
#         export_pixel_probabilities=True,
#         export_2d_patches_for_training=True,
#         export_2d_patches_for_annotation=True,
#         draw_bounding_box_on_2d_patch=True,
#         draw_contour_on_2d_patch=False,
#         draw_mask_on_2d_patch=False,
#         export_3d_patches=True,
#         export_annotated_zstack=True,
#         draw_label_on_zstack=False,
#         export_patch_masks=True,
#         rgb_overlay_channels=(None, None, None),
#         rgb_overlay_weights=(1.0, 1.0, 1.0),
# ) -> Dict:
#     pixel_classifier = models[0]
#
#     # ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta(
#     #     input_file_path,
#     #     pixel_classifier,
#     #     segmentation_channel,
#     #     pxmap_threshold,
#     #     pxmap_foreground_channel=pxmap_foreground_channel,
#     #     zmask_zindex=zmask_zindex,
#     #     zmask_clip=zmask_clip,
#     #     zmask_expand_box_by=zmask_expand_box_by,
#     #     zmask_filters=zmask_filters,
#     #     zmask_type=zmask_type,
#     # )
#
#     # obj_table = ZMaskObjectTable(
#     #     obmask.get_one_channel_data(pxmap_foreground_channel),
#     #     stack.get_one_channel_data(segmentation_channel),
#     #     mask_type=zmask_type,
#     #     filters=zmask_filters,
#     #     expand_box_by=kwargs['zmask_expand_box_by'],
#     # )
#
#     if export_pixel_probabilities:
#         write_accessor_data_to_file(
#             Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'),
#             pxmap
#         )
#         ti.click('export_pixel_probability')
#
#     if export_3d_patches and len(zmask_meta) > 0:
#         files = export_patches_from_zstack(
#             Path(output_folder_path) / '3d_patches',
#             stack.get_one_channel_data(patches_channel),
#             zmask_meta,
#             prefix=fstem,
#             draw_bounding_box=False,
#             rescale_clip=0.001,
#             make_3d=True,
#         )
#         ti.click('export_3d_patches')
#
#     if export_2d_patches_for_annotation and len(zmask_meta) > 0:
#         files = export_multichannel_patches_from_zstack(
#             Path(output_folder_path) / '2d_patches_annotation',
#             stack,
#             zmask_meta,
#             prefix=fstem,
#             rescale_clip=0.001,
#             make_3d=False,
#             focus_metric='max_sobel',
#             ch_white=patches_channel,
#             ch_rgb_overlay=rgb_overlay_channels,
#             draw_bounding_box=draw_bounding_box_on_2d_patch,
#             bounding_box_channel=1,
#             bounding_box_linewidth=2,
#             draw_contour=draw_contour_on_2d_patch,
#             draw_mask=draw_mask_on_2d_patch,
#             overlay_gain=rgb_overlay_weights,
#         )
#         df_patches = pd.DataFrame(files)
#         ti.click('export_2d_patches')
#         # associate 2d patches, dropping labeled objects that were not exported as patches
#         df = pd.merge(df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
#         # prepopulate patch UUID
#         df['patch_id'] = df.apply(lambda _: uuid4(), axis=1)
#
#     if export_2d_patches_for_training and len(zmask_meta) > 0:
#         files = export_multichannel_patches_from_zstack(
#             Path(output_folder_path) / '2d_patches_training',
#             stack.get_one_channel_data(patches_channel),
#             zmask_meta,
#             prefix=fstem,
#             rescale_clip=0.001,
#             make_3d=False,
#             focus_metric='max_sobel',
#         )
#         ti.click('export_2d_patches')
#
#     if export_patch_masks and len(zmask_meta) > 0:
#         files = export_patch_masks_from_zstack(
#             Path(output_folder_path) / 'patch_masks',
#             stack.get_one_channel_data(patches_channel),
#             zmask_meta,
#             prefix=fstem,
#         )
#
#     if export_annotated_zstack:
#         annotated = InMemoryDataAccessor(
#             draw_boxes_on_3d_image(
#                 stack.get_one_channel_data(patches_channel).data,
#                 zmask_meta,
#                 add_label=draw_label_on_zstack,
#             )
#         )
#         write_accessor_data_to_file(
#             Path(output_folder_path) / 'annotated_zstacks' / (fstem + '.tif'),
#             annotated
#         )
#         ti.click('export_annotated_zstack')
#
#     # generate multichannel projection from label centroids
#     dff = df[df['keeper']]
#     if len(zmask_meta) > 0:
#         interm['projected'] = project_stack_from_focal_points(
#             dff['centroid-0'].to_numpy(),
#             dff['centroid-1'].to_numpy(),
#             dff['zi'].to_numpy(),
#             stack,
#             degree=4,
#         )
#     else: # else just return MIP
#         interm['projected'] = stack.data.max(axis=-1)
#
#     return {
#         'pixel_model_id': pixel_classifier.model_id,
#         'input_filepath': input_file_path,
#         'number_of_objects': len(zmask_meta),
#         'pixel_scale_in_micrometers': stack.pixel_scale_in_micrometers,
#         'success': True,
#         'timer_results': ti.events,
#         'dataframe': df[df['keeper'] == True],
#         'interm': interm,
#     }

def infer_object_map_from_zstack(
        input_file_path: str,
        output_folder_path: str,
        models: List[Model],
        pxmap_foreground_channel: int,
        pxmap_threshold: float,
        segmentation_channel: int,
        patches_channel: int,
        zmask_zindex: int = None,  # None for MIP,
        zmask_clip: int = None,
        zmask_type: str = 'boxes',
        zmask_filters: Dict = None,
        # zmask_expand_box_by: int = None,
        exports: RoiSetExportParams = RoiSetExportParams(),
        **kwargs,
) -> Dict:
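    """
    Classify pixels of a z-stack with a semantic segmentation model, threshold
    the resulting probability map into a set of ROIs, classify each ROI with an
    instance segmentation (object classifier) model, and export the requested
    products (object class map, patches, patch masks, annotated z-stack).

    :param input_file_path: path to the input z-stack image file
    :param output_folder_path: folder where all products are written
    :param models: exactly two models: a SemanticSegmentationModel (pixel
        classifier) followed by an InstanceSegmentationModel (object classifier)
    :param pxmap_foreground_channel: channel of the pixel probability map that
        is thresholded into the object mask
    :param pxmap_threshold: threshold applied to the pixel probability map
    :param segmentation_channel: channel of the input stack used for segmentation
    :param patches_channel: channel of the input stack used for patch export and
        object classification
    :param zmask_zindex: z-index used for segmentation; if None, use the maximum
        intensity projection (MIP)
    :param zmask_clip: if given, rescale/clip the segmentation image before
        pixel classification
    :param zmask_type: ROI mask geometry passed to RoiSet (default 'boxes')
    :param zmask_filters: optional object filters passed to RoiSet
    :param exports: RoiSetExportParams controlling which products are exported
    :param kwargs: additional keyword arguments (currently unused)
    :return: dict of timer results, the ROI dataframe, intermediate data, and
        the path of the exported object class map
    """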
    assert len(models) == 2
    pixel_classifier = models[0]
    assert isinstance(pixel_classifier, SemanticSegmentationModel)
    object_classifier = models[1]
    assert isinstance(object_classifier, InstanceSegmentationModel)

    ti = Timer()
    stack = generate_file_accessor(Path(input_file_path))
    fstem = Path(input_file_path).stem
    ti.click('file_input')

    # MIP if no zmask z-index is given, then classify pixels
    if isinstance(zmask_zindex, int):
        assert 0 <= zmask_zindex < stack.nz
        zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data[:, :, :, zmask_zindex]
    else:
        zmask_data = stack.get_one_channel_data(channel=segmentation_channel).data.max(axis=-1, keepdims=True)
    if zmask_clip:
        zmask_data = rescale(zmask_data, zmask_clip)
    mip = InMemoryDataAccessor(
        zmask_data,
    )
    pxmap, _ = pixel_classifier.infer(mip)
    ti.click('infer_pixel_probability')

    if exports.pixel_probabilities:
        write_accessor_data_to_file(
            Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'),
            pxmap
        )
        ti.click('export_pixel_probability')

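    # threshold the pixel probability map into a binary foreground object mask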
    obmask = InMemoryDataAccessor(
        pxmap.data > pxmap_threshold
    )
    ti.click('threshold_pixel_mask')

    # make zmask
    rois = RoiSet(
        obmask.get_one_channel_data(pxmap_foreground_channel),
        stack,
        mask_type=zmask_type,
        filters=zmask_filters,
        expand_box_by=exports.expand_box_by,
    )
    ti.click('generate_zmasks')

    # ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta(
    #     input_file_path,
    #     pixel_classifier,
    #     segmentation_channel,
    #     pxmap_threshold,
    #     pxmap_foreground_channel=pxmap_foreground_channel,
    #     zmask_zindex=zmask_zindex,
    #     zmask_clip=zmask_clip,
    #     # zmask_expand_box_by=zmask_expand_box_by,
    #     zmask_filters=zmask_filters,
    #     zmask_type=zmask_type,
    #     **kwargs
    # )

    # # extract patches to accessor
    # patches_acc = get_patches_from_zmask_meta(
    #     stack.get_one_channel_data(patches_channel),
    #     obj_table.zmask_meta,
    #     rescale_clip=zmask_clip,
    #     make_3d=False,
    #     focus_metric='max_sobel',
    #     **kwargs
    # )
    #
    # # extract masks
    # patch_masks_acc = get_patch_masks_from_zmask_meta(
    #     stack,
    #     obj_table.zmask_meta,
    #     **kwargs
    # )

    # # send patches and mask stacks to object classifier
    # result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc)

    # labels_map = obj_table.interm['label_map']
    # output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype)
    # assert labels_map.shape == obj_table.get_label_map().shape
    # assert labels_map.dtype == obj_table.get_label_map().dtype
    #
    # # assign labels to object map:
    # meta = []
    # for ii in range(0, len(obj_table.zmask_meta)):
    #     object_id = obj_table.zmask_meta[ii]['info'].label
    #     result_patch = mask_largest_object(result_acc.iat(ii))
    #     object_class = np.unique(result_patch)[1]
    #     output_map[labels_map == object_id] = object_class
    #     meta.append({'object_id': ii, 'object_class': object_id})

    object_class_map = rois.classify_by(patches_channel, object_classifier)

    # TODO: add RoiSet method to export the object class map directly
    output_path = Path(output_folder_path) / ('obj_classes_' + fstem + '.tif')
    write_accessor_data_to_file(
        output_path,
        object_class_map
    )
    ti.click('export_object_classes')

    if exports.patches_3d:
        rois.export_3d_patches(
            Path(output_folder_path) / '3d_patches',
            fstem,
            patches_channel,
            exports.patches_3d
        )
        ti.click('export_3d_patches')

    if exports.patches_2d_for_annotation:
        rois.export_2d_patches_for_annotation(
            Path(output_folder_path) / '2d_patches_annotation',
            fstem,
            patches_channel,
            exports.patches_2d_for_annotation
        )
        ti.click('export_2d_patches_for_annotation')

    if exports.patches_2d_for_training:
        rois.export_2d_patches_for_training(
            Path(output_folder_path) / '2d_patches_training',
            fstem,
            patches_channel,
            exports.patches_2d_for_training
        )
        ti.click('export_2d_patches_for_training')

    if exports.patch_masks:
        rois.export_patch_masks(
            Path(output_folder_path) / 'patch_masks',
            fstem,
            patches_channel,
            exports.patch_masks
        )
        ti.click('export_patch_masks')

    if exports.annotated_z_stack:
        rois.export_annotated_zstack(
            Path(output_folder_path) / 'annotated_zstacks',
            fstem,
            patches_channel,
            exports.annotated_z_stack
        )
        ti.click('export_annotated_zstack')

    return {
        'timer_results': ti.events,
        'dataframe': rois.df,
        'interm': {},
        'output_path': str(output_path),
    }
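
# Example (illustrative sketch, not executed): the paths and channel indices below
# are placeholders, and pixel_model / object_model stand for any concrete
# SemanticSegmentationModel and InstanceSegmentationModel instances, e.g. an
# ilastik pixel classifier and an object classifier.
#
#   result = infer_object_map_from_zstack(
#       input_file_path='/data/stacks/example.czi',
#       output_folder_path='/data/output',
#       models=[pixel_model, object_model],
#       pxmap_foreground_channel=0,
#       pxmap_threshold=0.6,
#       segmentation_channel=0,
#       patches_channel=1,
#       exports=RoiSetExportParams(),
#   )
#   print(result['output_path'], len(result['dataframe']))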



def transfer_ecotaxa_labels_to_patch_stacks(
    where_masks: str,
    where_patches: str,
    object_csv: str,
    ecotaxa_tsv: str,
    where_output: str,
    patch_size: tuple = (256, 256),
    tr_split=0.6,
    dilate_label_mask: bool = True, # to mitigate connected components error in ilastik
    allow_multiple_objects: bool = False,
) -> None:
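    """
    Merge EcoTaxa annotations into locally exported patch metadata, assign an
    integer class id to each lowest-level annotation, split the patches into
    training and test sets, and write each set out as mask, label, and raw
    z-stacks (plus CSVs of the label key and per-stack metadata).

    :param where_masks: folder containing the exported patch masks
    :param where_patches: folder containing the exported patch images
    :param object_csv: CSV of patch metadata, including patch_id and patch_filename
    :param ecotaxa_tsv: EcoTaxa export TSV containing the object annotations
    :param where_output: folder where z-stacks and CSVs are written
    :param patch_size: expected (height, width) of each patch
    :param tr_split: fraction of patches assigned to the training set (must exceed 0.5)
    :param dilate_label_mask: dilate each mask slightly, to mitigate a
        connected-components issue in ilastik
    :param allow_multiple_objects: if False, keep only the largest labeled
        object in each patch mask
    """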
    assert tr_split > 0.5 # reduce chance that low-probability objects are omitted from training

    # read patch metadata
    df_obj = pd.read_csv(
        object_csv,
    )
    df_ecotaxa = pd.read_csv(
        ecotaxa_tsv,
        sep='\t',
        header=[0],
        dtype={
            'object_annotation_date': str,
            'object_annotation_time': str,
            'object_annotation_category_id': str,
        }
    )
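    # join exported patch metadata with the EcoTaxa annotations (patch_id <-> object_id)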
    df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')

    # assign each unique lowest-level annotation to a class index
    se_unique = pd.Series(
        df_merge.object_annotation_hierarchy.unique()
    )
    df_split = (
        se_unique.str.rsplit(
            pat='>', n=1, expand=True
        )
    )
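    # class ids start at 1, leaving 0 free as background in the label stacks below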
    df_labels = pd.DataFrame({
        'annotation_class_id': df_split.index + 1,
        'hierarchy': se_unique,
        'annotation_class': df_split.loc[:, 1].str.lower()
    })

    # join patch filenames and annotation classes
    df_pf = pd.merge(
        df_merge[['patch_filename', 'object_annotation_hierarchy']],
        df_labels,
        left_on='object_annotation_hierarchy',
        right_on='hierarchy',
    )
    df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]

    # export annotation classes and their summary stats
    df_tr, df_te = train_test_split(df_pl, train_size=tr_split)
    # df_labels['counts'] = df_pl['annotation_class_id'].value_counts()
    df_labels = pd.merge(
        df_labels,
        pd.DataFrame(
            [df_pl.annotation_class_id.value_counts(), df_tr.annotation_class_id.value_counts(), df_te.annotation_class_id.value_counts()],
            index=['total', 'to_train', 'to_test']
        ).T,
        left_on='annotation_class_id',
        right_index=True,
        how='outer'
    )
    for col in ['total', 'to_train', 'to_test']:
        df_labels.loc[df_labels[col].isna(), col] = 0
    df_labels.to_csv(Path(where_output) / 'labels_key.csv', index=False)

    # export patches as z-stacks
    for (dfk, dfv) in {'train': df_tr, 'test': df_te}.items():
        zstack_keys = ['mask', 'label', 'raw']
        zstacks = {f'{dfk}_{zsk}': np.zeros((*patch_size, 1, len(dfv)), dtype='uint8') for zsk in zstack_keys}
        stack_meta = []
        for fi, pl in enumerate(dfv.itertuples(name='PatchFile')):
            fn = pl._asdict()['patch_filename']
            ac = pl._asdict()['annotation_class']
            aci = pl._asdict()['annotation_class_id']

            stack_meta.append({'zi': fi, 'patch_filename': fn, 'annotation_class': ac, 'annotation_class_id': aci})
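            # load the binary patch mask and check its geometry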
            acc_bm = generate_file_accessor(Path(where_masks) / fn)
            assert acc_bm.is_mask()
            assert acc_bm.hw == patch_size, f'Expected patch size {patch_size}, got {acc_bm.hw}'
            assert acc_bm.chroma == 1
            assert acc_bm.nz == 1
            mask = acc_bm.data[:, :, 0, 0]
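            # optionally dilate the mask and/or keep only its largest connected object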
            if dilate_label_mask:
                mask = dilation(mask)
            if not allow_multiple_objects:
                ob_id = label(mask)
                mask = mask_largest_object(ob_id)
            zstacks[dfk + '_mask'][:, :, 0, fi] = mask
            zstacks[dfk + '_label'][:, :, 0, fi] = (mask == 255) * aci

            acc_pa = generate_file_accessor(Path(where_patches) / fn)
            zstacks[dfk + '_raw'][:, :, :, fi] = acc_pa.data[:, :, :, 0]

        for k in zstacks.keys():
            write_accessor_data_to_file(Path(where_output) / f'zstack_{k}.tif', InMemoryDataAccessor(zstacks[k]))

        pd.DataFrame(stack_meta).to_csv(Path(where_output) / f'{dfk}_stack.csv', index=False)
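
# Example (illustrative sketch, not executed): the folder layout below is a
# placeholder; object_csv and ecotaxa_tsv are assumed to be the patch metadata
# CSV exported alongside the patches and the matching EcoTaxa export TSV.
#
#   transfer_ecotaxa_labels_to_patch_stacks(
#       where_masks='/data/output/patch_masks',
#       where_patches='/data/output/2d_patches_training',
#       object_csv='/data/output/patch_metadata.csv',
#       ecotaxa_tsv='/data/ecotaxa_export.tsv',
#       where_output='/data/output/training_stacks',
#       patch_size=(256, 256),
#       tr_split=0.6,
#   )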