from pathlib import Path
from typing import Dict, List
from uuid import uuid4

import numpy as np
import pandas as pd
from skimage.measure import label, regionprops_table
from skimage.morphology import dilation
from sklearn.model_selection import train_test_split

from extensions.chaeo.accessors import MonoPatchStack
from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.models import PatchStackObjectClassifier
from extensions.chaeo.process import mask_largest_object
from extensions.chaeo.products import (
    export_patches_from_zstack,
    export_patch_masks_from_zstack,
    export_multichannel_patches_from_zstack,
    get_patches_from_zmask_meta,
    get_patch_masks_from_zmask_meta,
)
from extensions.chaeo.zmask import build_zmask_from_object_mask, project_stack_from_focal_points
from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import Model
from model_server.process import rescale
from model_server.workflows import Timer


def get_zmask_meta(
    input_file_path: str,
    ilastik_pixel_classifier: IlastikPixelClassifierModel,
    segmentation_channel: int,
    pxmap_threshold: float,
    pxmap_foreground_channel: int = 0,
    zmask_zindex: int = None,
    zmask_clip: int = None,
    zmask_filters: Dict = None,
    zmask_type: str = 'boxes',
    **kwargs,
) -> tuple:
    """
    Load a z-stack, classify its pixels, threshold the result and build a z-mask from it.

    :param input_file_path: path of the z-stack to open
    :param ilastik_pixel_classifier: model whose infer() produces the pixel probability map
    :param segmentation_channel: channel of the stack used for pixel classification and z-mask building
    :param pxmap_threshold: pixels with probability above this value become object-mask foreground
    :param pxmap_foreground_channel: channel of the thresholded map passed on as the object mask
    :param zmask_zindex: z-slice to classify; any non-int value (e.g. None) means use the maximum
        intensity projection along z instead
    :param zmask_clip: if truthy, passed to rescale() before pixel classification
    :param zmask_filters: forwarded to build_zmask_from_object_mask as `filters`
    :param zmask_type: forwarded to build_zmask_from_object_mask as `mask_type`
    :param kwargs: only 'zmask_expand_box_by' is read here; it is optional and defaults to None
    :return: tuple of (timer, stack accessor, input file stem, object mask accessor, pixel map accessor,
        zmask, zmask_meta, dataframe, intermediate results)
    """
    ti = Timer()
    stack = generate_file_accessor(Path(input_file_path))
    fstem = Path(input_file_path).stem
    ti.click('file_input')

    # MIP if no zmask z-index is given, then classify pixels
    seg_data = stack.get_one_channel_data(channel=segmentation_channel).data
    if isinstance(zmask_zindex, int):
        # index 0 is a valid slice, so the lower bound is inclusive
        assert 0 <= zmask_zindex < stack.nz
        # slice rather than index so the z axis is kept, matching keepdims=True in the MIP branch
        zmask_data = seg_data[:, :, :, zmask_zindex:zmask_zindex + 1]
    else:
        zmask_data = seg_data.max(axis=-1, keepdims=True)
    if zmask_clip:
        zmask_data = rescale(zmask_data, zmask_clip)
    mip = InMemoryDataAccessor(zmask_data)
    pxmap, _ = ilastik_pixel_classifier.infer(mip)
    ti.click('infer_pixel_probability')

    obmask = InMemoryDataAccessor(pxmap.data > pxmap_threshold)
    ti.click('threshold_pixel_mask')

    # make zmask
    zmask, zmask_meta, df, interm = build_zmask_from_object_mask(
        obmask.get_one_channel_data(pxmap_foreground_channel),
        stack.get_one_channel_data(segmentation_channel),
        mask_type=zmask_type,
        filters=zmask_filters,
        # optional: callers that forward **kwargs may not supply this key at all
        expand_box_by=kwargs.get('zmask_expand_box_by'),
    )
    ti.click('generate_zmasks')

    # record pixel scale
    df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X'))

    return ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm


# TODO: unpack and validate inputs
def export_patches_from_multichannel_zstack(
    input_file_path: str,
    output_folder_path: str,
    models: List[Model],
    pxmap_threshold: float,
    pxmap_foreground_channel: int,
    segmentation_channel: int,
    patches_channel: int,
    zmask_zindex: int = None,  # None for MIP,
    zmask_clip: int = None,
    zmask_type: str = 'boxes',
    zmask_filters: Dict = None,
    zmask_expand_box_by: int = None,
    export_pixel_probabilities=True,
    export_2d_patches_for_training=True,
    export_2d_patches_for_annotation=True,
    draw_bounding_box_on_2d_patch=True,
    draw_contour_on_2d_patch=False,
    draw_mask_on_2d_patch=False,
    export_3d_patches=True,
    export_annotated_zstack=True,
    draw_label_on_zstack=False,
    export_patch_masks=True,
    rgb_overlay_channels=(None, None, None),
    rgb_overlay_weights=(1.0, 1.0, 1.0),
) -> Dict:
    """
    Segment objects in a z-stack and export the products selected by the export_* flags.

    The first entry of `models` is used as the pixel classifier.  Each product is written to its own
    subfolder of `output_folder_path`; patch products are skipped when no objects were found.

    :param input_file_path: path of the z-stack to process
    :param output_folder_path: root folder under which product subfolders are written
    :param models: list of models; only models[0] is used here
    :param patches_channel: channel from which patches and the annotated z-stack are drawn
    :return: dict with the pixel model ID, input path, object count, pixel scale, success flag, timer
        events, the dataframe rows flagged as keepers, and intermediate results including 'projected'
    """
    pixel_classifier = models[0]
    where_output = Path(output_folder_path)

    ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta(
        input_file_path,
        pixel_classifier,
        segmentation_channel,
        pxmap_threshold,
        pxmap_foreground_channel=pxmap_foreground_channel,
        zmask_zindex=zmask_zindex,
        zmask_clip=zmask_clip,
        zmask_expand_box_by=zmask_expand_box_by,
        zmask_filters=zmask_filters,
        zmask_type=zmask_type,
    )
    has_objects = len(zmask_meta) > 0

    if export_pixel_probabilities:
        write_accessor_data_to_file(
            where_output / 'pixel_probabilities' / (fstem + '.tif'),
            pxmap
        )
        ti.click('export_pixel_probability')

    if export_3d_patches and has_objects:
        export_patches_from_zstack(
            where_output / '3d_patches',
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
            prefix=fstem,
            draw_bounding_box=False,
            rescale_clip=0.001,
            make_3d=True,
        )
        ti.click('export_3d_patches')

    if export_2d_patches_for_annotation and has_objects:
        files = export_multichannel_patches_from_zstack(
            where_output / '2d_patches_annotation',
            stack,
            zmask_meta,
            prefix=fstem,
            rescale_clip=0.001,
            make_3d=False,
            focus_metric='max_sobel',
            ch_white=patches_channel,
            ch_rgb_overlay=rgb_overlay_channels,
            draw_bounding_box=draw_bounding_box_on_2d_patch,
            bounding_box_channel=1,
            bounding_box_linewidth=2,
            draw_contour=draw_contour_on_2d_patch,
            draw_mask=draw_mask_on_2d_patch,
            overlay_gain=rgb_overlay_weights,
        )
        df_patches = pd.DataFrame(files)
        ti.click('export_2d_patches')

        # associate 2d patches, dropping labeled objects that were not exported as patches
        df = pd.merge(df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')

        # prepopulate patch UUID
        df['patch_id'] = df.apply(lambda _: uuid4(), axis=1)

    if export_2d_patches_for_training and has_objects:
        export_multichannel_patches_from_zstack(
            where_output / '2d_patches_training',
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
            prefix=fstem,
            rescale_clip=0.001,
            make_3d=False,
            focus_metric='max_sobel',
        )
        ti.click('export_2d_patches')

    if export_patch_masks and has_objects:
        export_patch_masks_from_zstack(
            where_output / 'patch_masks',
            stack.get_one_channel_data(patches_channel),
            zmask_meta,
            prefix=fstem,
        )

    if export_annotated_zstack:
        annotated = InMemoryDataAccessor(
            draw_boxes_on_3d_image(
                stack.get_one_channel_data(patches_channel).data,
                zmask_meta,
                add_label=draw_label_on_zstack,
            )
        )
        write_accessor_data_to_file(
            where_output / 'annotated_zstacks' / (fstem + '.tif'),
            annotated
        )
        ti.click('export_annotated_zstack')

    # generate multichannel projection from label centroids
    dff = df[df['keeper']]
    if has_objects:
        interm['projected'] = project_stack_from_focal_points(
            dff['centroid-0'].to_numpy(),
            dff['centroid-1'].to_numpy(),
            dff['zi'].to_numpy(),
            stack,
            degree=4,
        )
    else:  # else just return MIP
        interm['projected'] = stack.data.max(axis=-1)

    return {
        'pixel_model_id': pixel_classifier.model_id,
        'input_filepath': input_file_path,
        'number_of_objects': len(zmask_meta),
        # NOTE(review): key is misspelled (capital 'L'); kept as-is because callers may read it
        'pixeL_scale_in_micrometers': stack.pixel_scale_in_micrometers,
        'success': True,
        'timer_results': ti.events,
        'dataframe': df[df['keeper'] == True],
        'interm': interm,
    }


def infer_object_map_from_zstack(
    input_file_path: str,
    output_folder_path: str,
    models: List[Model],
    pxmap_threshold: float,
    pxmap_foreground_channel: int,
    segmentation_channel: int,
    patches_channel: int,
    zmask_zindex: int = None,  # None for MIP,
    zmask_clip: int = None,
    zmask_type: str = 'boxes',
    zmask_filters: Dict = None,
    **kwargs,
) -> Dict:
    """
    Segment objects in a z-stack, classify each one, and write a map of object classes to file.

    :param input_file_path: path of the z-stack to process
    :param output_folder_path: folder where 'obj_classes_<input stem>.tif' is written
    :param models: exactly two models: an IlastikPixelClassifierModel followed by a
        PatchStackObjectClassifier
    :param patches_channel: channel from which patches are extracted for object classification
    :param kwargs: forwarded unchanged to get_zmask_meta, get_patches_from_zmask_meta and
        get_patch_masks_from_zmask_meta
    :return: dict with timer events, a dataframe holding one row per object ('object_id' is the
        object's index in zmask_meta, 'object_class' its inferred class), an empty 'interm' dict,
        and the output path as a string
    """
    assert len(models) == 2
    pixel_classifier = models[0]
    assert isinstance(pixel_classifier, IlastikPixelClassifierModel)
    object_classifier = models[1]
    assert isinstance(object_classifier, PatchStackObjectClassifier)

    ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta(
        input_file_path,
        pixel_classifier,
        segmentation_channel,
        pxmap_threshold,
        pxmap_foreground_channel=pxmap_foreground_channel,
        zmask_zindex=zmask_zindex,
        zmask_clip=zmask_clip,
        zmask_filters=zmask_filters,
        zmask_type=zmask_type,
        **kwargs
    )

    # extract patches to accessor
    patches_acc = get_patches_from_zmask_meta(
        stack.get_one_channel_data(patches_channel),
        zmask_meta,
        rescale_clip=zmask_clip,
        make_3d=False,
        focus_metric='max_sobel',
        **kwargs
    )

    # extract masks
    patch_masks_acc = get_patch_masks_from_zmask_meta(
        stack,
        zmask_meta,
        **kwargs
    )

    # send patches and mask stacks to object classifier
    result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc)

    labels_map = interm['label_map']
    output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype)
    assert output_map.shape == labels_map.shape
    assert output_map.dtype == labels_map.dtype

    # assign labels to object map:
    meta = []
    for ii, zmask_entry in enumerate(zmask_meta):
        object_id = zmask_entry['info'].label
        result_patch = mask_largest_object(result_acc.iat(ii))
        # index 1 skips the smallest unique value, presumably the background
        # NOTE(review): raises IndexError if the patch holds a single unique value
        object_class = np.unique(result_patch)[1]
        output_map[labels_map == object_id] = object_class
        # record the inferred class; the label ID was previously stored here by mistake
        meta.append({'object_id': ii, 'object_class': object_class})

    output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))
    write_accessor_data_to_file(
        output_path,
        InMemoryDataAccessor(output_map)
    )
    ti.click('export_object_classes')

    return {
        'timer_results': ti.events,
        'dataframe': pd.DataFrame(meta),
        'interm': {},
        'output_path': str(output_path),
    }


def transfer_ecotaxa_labels_to_patch_stacks(
    where_masks: str,
    where_patches: str,
    object_csv: str,
    ecotaxa_tsv: str,
    where_output: str,
    patch_size: tuple = (256, 256),
    tr_split=0.6,
    dilate_label_mask: bool = True,  # to mitigate connected components error in ilastik
    allow_multiple_objects: bool = False,
) -> Dict:
    """
    Combine patch metadata with EcoTaxa annotations and export train/test patch stacks.

    Writes to `where_output`: 'labels_key.csv' (class index, hierarchy, class name and counts per
    split), and for each of the train and test splits a 'zstack_<split>_{mask,label,raw}.tif' stack
    plus a '<split>_stack.csv' listing the patch stored at each z-position.

    NOTE: nothing is returned, despite the return annotation.

    :param where_masks: folder containing patch mask files
    :param where_patches: folder containing raw patch files, named like the masks
    :param object_csv: CSV of patch metadata; joined on its 'patch_id' column
    :param ecotaxa_tsv: tab-separated EcoTaxa export; joined on its 'object_id' column
    :param where_output: folder where all outputs are written
    :param patch_size: expected (height, width) of each patch
    :param tr_split: fraction of annotated patches assigned to the training split; must exceed 0.5
    :param dilate_label_mask: if True, dilate each mask before it is stored
    :param allow_multiple_objects: if False, keep only the largest connected object in each mask
    """
    assert tr_split > 0.5  # reduce chance that low-probability objects are omitted from training

    # read patch metadata
    df_obj = pd.read_csv(
        object_csv,
    )
    df_ecotaxa = pd.read_csv(
        ecotaxa_tsv,
        sep='\t',
        header=[0],
        dtype={
            ('object_annotation_date', '[t]'): str,
            ('object_annotation_time', '[t]'): str,
            ('object_annotation_category_id', '[t]'): str,
        }
    )
    df_merge = pd.merge(df_obj, df_ecotaxa, left_on='patch_id', right_on='object_id')

    # assign each unique lowest-level annotation to a class index
    se_unique = pd.Series(
        df_merge.object_annotation_hierarchy.unique()
    )
    df_split = (
        se_unique.str.rsplit(
            pat='>', n=1, expand=True
        )
    )
    df_labels = pd.DataFrame({
        'annotation_class_id': df_split.index + 1,
        'hierarchy': se_unique,
        'annotation_class': df_split.loc[:, 1].str.lower()
    })

    # join patch filenames and annotation classes
    df_pf = pd.merge(
        df_merge[['patch_filename', 'object_annotation_hierarchy']],
        df_labels,
        left_on='object_annotation_hierarchy',
        right_on='hierarchy',
    )
    df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]

    # export annotation classes and their summary stats
    df_tr, df_te = train_test_split(df_pl, train_size=tr_split)
    count_columns = ['total', 'to_train', 'to_test']
    df_labels = pd.merge(
        df_labels,
        pd.DataFrame(
            [
                df_pl.annotation_class_id.value_counts(),
                df_tr.annotation_class_id.value_counts(),
                df_te.annotation_class_id.value_counts(),
            ],
            index=count_columns
        ).T,
        left_on='annotation_class_id',
        right_index=True,
        how='outer'
    )
    # classes absent from a split have no count after the outer merge; report them as zero
    for col in count_columns:
        df_labels.loc[df_labels[col].isna(), col] = 0
    df_labels.to_csv(Path(where_output) / 'labels_key.csv', index=False)

    # export patches as z-stacks
    for (dfk, dfv) in {'train': df_tr, 'test': df_te}.items():
        zstack_keys = ['mask', 'label', 'raw']
        zstacks = {f'{dfk}_{zsk}': np.zeros((*patch_size, 1, len(dfv)), dtype='uint8') for zsk in zstack_keys}
        stack_meta = []
        for fi, pl in enumerate(dfv.itertuples(name='PatchFile')):
            row = pl._asdict()
            fn = row['patch_filename']
            ac = row['annotation_class']
            aci = row['annotation_class_id']

            stack_meta.append({'zi': fi, 'patch_filename': fn, 'annotation_class': ac, 'annotation_class_id': aci})
            acc_bm = generate_file_accessor(Path(where_masks) / fn)
            assert acc_bm.is_mask()
            assert acc_bm.hw == patch_size, f'Unexpected patch size {acc_bm.hw}, expected {patch_size}'
            assert acc_bm.chroma == 1
            assert acc_bm.nz == 1
            mask = acc_bm.data[:, :, 0, 0]

            if dilate_label_mask:
                mask = dilation(mask)

            if not allow_multiple_objects:
                # Gate the (possibly dilated) mask by its largest connected component so the mask
                # keeps its original foreground value, which the label plane tests against below.
                # Assigning the returned label image directly would replace that value with label IDs
                # and discard the dilation.
                # NOTE(review): presumes mask_largest_object returns its input zeroed outside the
                # largest object, as its other use in this module suggests -- confirm
                largest = mask_largest_object(label(mask))
                mask = mask * (largest > 0)

            zstacks[dfk + '_mask'][:, :, 0, fi] = mask
            zstacks[dfk + '_label'][:, :, 0, fi] = (mask == 255) * aci

            acc_pa = generate_file_accessor(Path(where_patches) / fn)
            zstacks[dfk + '_raw'][:, :, :, fi] = acc_pa.data[:, :, :, 0]

        for k in zstacks.keys():
            write_accessor_data_to_file(Path(where_output) / f'zstack_{k}.tif', InMemoryDataAccessor(zstacks[k]))

        pd.DataFrame(stack_meta).to_csv(Path(where_output) / f'{dfk}_stack.csv', index=False)