# workflows.py
from skimage.measure import label, regionprops_table
from skimage.morphology import dilation
from sklearn.model_selection import train_test_split
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from extensions.chaeo.process import mask_largest_object
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import Model, InstanceSegmentationModel, SemanticSegmentationModel

def infer_object_map_from_zstack(
        input_file_path: str,
        output_folder_path: str,
        models: dict,
        segmentation_channel: int,
        patches_channel: int,
        zmask_zindex: int = None,  # None for MIP,
        roi_params: RoiSetMetaParams = RoiSetMetaParams(),
        export_params: RoiSetExportParams = RoiSetExportParams(),
) -> dict:
    """
    Classify the pixels of a z-stack, build an ROI set from the resulting mask, classify the ROIs, and export
    ROI products.

    :param input_file_path: path of the z-stack image file to read
    :param output_folder_path: folder passed to the ROI export step
    :param models: mapping with entries 'pixel_classifier' and 'object_classifier'; each entry is a mapping that
        holds the model instance under 'model'.  The pixel classifier entry also holds 'params', which are forwarded
        as keyword arguments to label_pixel_class()
    :param segmentation_channel: channel of the stack that is handed to the pixel classifier
    :param patches_channel: channel used for object classification and for the exports
    :param zmask_zindex: z-index of the single plane to classify; when this is not an int (e.g. None), the
        maximum along the z axis is classified instead
    :param roi_params: parameters passed to RoiSet
    :param export_params: parameters passed to RoiSet.run_exports()
    :return: dict that contains the output folder under 'output_path'
    :raises AssertionError: if a model is not of the expected class or zmask_zindex is outside of the stack
    """
    pixel_classifier = models['pixel_classifier']
    object_classifier = models['object_classifier']
    assert isinstance(pixel_classifier['model'], SemanticSegmentationModel)
    assert isinstance(object_classifier['model'], InstanceSegmentationModel)

    ti = Timer()
    stack = generate_file_accessor(Path(input_file_path))
    fstem = Path(input_file_path).stem
    ti.click('file_input')

    # MIP if no zmask z-index is given, then classify pixels
    seg_data = stack.get_one_channel_data(channel=segmentation_channel).data
    if isinstance(zmask_zindex, int):
        # index 0 is a valid plane, so the lower bound is inclusive
        assert 0 <= zmask_zindex < stack.nz
        # slice instead of index so that the z axis is kept, consistent with keepdims=True in the MIP branch
        zmask_data = seg_data[:, :, :, zmask_zindex:zmask_zindex + 1]
    else:
        zmask_data = seg_data.max(axis=-1, keepdims=True)
    mip = InMemoryDataAccessor(zmask_data)
    mip_mask = pixel_classifier['model'].label_pixel_class(mip, **pixel_classifier['params'])

    # the timer event below covers pixel classification, ROI creation, and object classification
    rois = RoiSet(mip_mask, stack, params=roi_params)
    rois.classify_by(patches_channel, object_classifier['model'])
    ti.click('classify_objects')

    rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
    ti.click('export_roi_products')

    # NOTE(review): only the 'output_path' entry of the returned mapping survives in the available source; restore
    # any further entries (e.g. timer results) from version control if callers depend on them
    return {
        'output_path': output_folder_path,
    }
def transfer_ecotaxa_labels_to_patch_stacks(
    where_masks: str,
    where_patches: str,
    object_csv: str,
    ecotaxa_tsv: str,
    where_output: str,
    patch_size: tuple = (256, 256),
    tr_split=0.6,
    dilate_label_mask: bool = True, # to mitigate connected components error in ilastik
    allow_multiple_objects: bool = False,
):
    """
    Join patch metadata with EcoTaxa annotations, split the annotated patches into training and test sets, and
    write each set as mask, label, and raw z-stacks along with CSV keys.

    Files written into where_output: labels_key.csv, zstack_{train,test}_{mask,label,raw}.tif, and
    {train,test}_stack.csv.

    :param where_masks: folder that contains one mask file per patch, named by the patch_filename column
    :param where_patches: folder that contains the raw patch files, with the same filenames as the masks
    :param object_csv: CSV table of patches; read columns are patch_id and patch_filename
    :param ecotaxa_tsv: EcoTaxa export table; read columns are object_id and object_annotation_hierarchy
    :param where_output: folder that receives all outputs
    :param patch_size: expected (height, width) of every mask
    :param tr_split: fraction of annotated patches assigned to the training set; must exceed 0.5
    :param dilate_label_mask: if True, dilate each mask before it is written
    :param allow_multiple_objects: if False, keep only the largest connected object of each mask
    :raises AssertionError: if tr_split <= 0.5 or a mask file does not have the expected format
    """
    assert tr_split > 0.5 # reduce chance that low-probability objects are omitted from training
    output_dir = Path(where_output)

    # read patch metadata
    df_obj = pd.read_csv(
        object_csv,
    )
    # NOTE(review): sep, header, and the column flattening below are reconstructed: the (name, type-code) dtype keys
    # require a two-row header, and the code that follows addresses columns by their plain names; confirm against
    # version control
    df_ecotaxa = pd.read_csv(
        ecotaxa_tsv,
        sep='\t',
        header=[0, 1],
        dtype={
            ('object_annotation_date', '[t]'): str,
            ('object_annotation_time', '[t]'): str,
            ('object_annotation_category_id', '[t]'): str,
        }
    )
    df_ecotaxa.columns = df_ecotaxa.columns.get_level_values(0)
    df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')

    # assign each unique lowest-level annotation to a class index
    se_unique = pd.Series(
        df_merge.object_annotation_hierarchy.unique()
    )
    df_labels = pd.DataFrame({
        'annotation_class_id': se_unique.index + 1,
        'hierarchy': se_unique,
        # take the last element of the split so that a hierarchy without any '>' still yields its own name
        'annotation_class': se_unique.str.rsplit(pat='>', n=1).str[-1].str.lower(),
    })

    # join patch filenames and annotation classes
    df_pf = pd.merge(
        df_merge[['patch_filename', 'object_annotation_hierarchy']],
        df_labels,
        left_on='object_annotation_hierarchy',
        right_on='hierarchy',
    )
    df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]

    # export annotation classes and their summary stats
    df_tr, df_te = train_test_split(df_pl, train_size=tr_split)
    count_cols = ['total', 'to_train', 'to_test']
    df_counts = pd.DataFrame(
        [
            df_pl.annotation_class_id.value_counts(),
            df_tr.annotation_class_id.value_counts(),
            df_te.annotation_class_id.value_counts(),
        ],
        index=count_cols,
    ).T
    df_labels = pd.merge(
        df_labels,
        df_counts,
        left_on='annotation_class_id',
        right_index=True,
        how='outer'
    )
    # classes that are absent from a subset have no count after the outer join; report these as zero
    # NOTE(review): the loop header was missing from the available source; confirm that 'total' belongs in this list
    for col in count_cols:
        df_labels.loc[df_labels[col].isna(), col] = 0
    df_labels.to_csv(output_dir / 'labels_key.csv', index=False)

    # export patches as z-stacks
    for (dfk, dfv) in {'train': df_tr, 'test': df_te}.items():
        zstack_keys = ['mask', 'label', 'raw']
        zstacks = {f'{dfk}_{zsk}': np.zeros((*patch_size, 1, len(dfv)), dtype='uint8') for zsk in zstack_keys}
        stack_meta = []
        for fi, pl in enumerate(dfv.itertuples(name='PatchFile')):
            fn = pl.patch_filename
            ac = pl.annotation_class
            aci = pl.annotation_class_id
            stack_meta.append({'zi': fi, 'patch_filename': fn, 'annotation_class': ac, 'annotation_class_id': aci})

            acc_bm = generate_file_accessor(Path(where_masks) / fn)
            assert acc_bm.is_mask()
            assert acc_bm.hw == patch_size, f'Unexpected patch size {acc_bm.hw}, expected {patch_size}'
            assert acc_bm.chroma == 1
            assert acc_bm.nz == 1

            mask = acc_bm.data[:, :, 0, 0]
            if dilate_label_mask:
                mask = dilation(mask)
            if not allow_multiple_objects:
                # label the (possibly dilated) mask, not the file data, so that the dilation is not discarded
                mask = mask_largest_object(label(mask))

            zstacks[dfk + '_mask'][:, :, 0, fi] = mask
            # label stack carries the class index wherever the mask is set
            # NOTE(review): stacks are uint8, so class indices above 255 would not fit
            zstacks[dfk + '_label'][:, :, 0, fi] = (mask == 255) * aci

            acc_pa = generate_file_accessor(Path(where_patches) / fn)
            zstacks[dfk + '_raw'][:, :, :, fi] = acc_pa.data[:, :, :, 0]

        for k, zstack_data in zstacks.items():
            write_accessor_data_to_file(output_dir / f'zstack_{k}.tif', InMemoryDataAccessor(zstack_data))
        pd.DataFrame(stack_meta).to_csv(output_dir / f'{dfk}_stack.csv', index=False)