diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 9cd7078b3a0a51c2abf02055d4549efd11d73b06..7289c58ec63dd4fda637a82f57120f15a95c12a4 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -7,8 +7,9 @@ import pandas as pd from skimage.morphology import dilation from sklearn.model_selection import train_test_split -from extensions.ilastik.models import IlastikPixelClassifierModel +from extensions.chaeo.accessors import MonoPatchStack from extensions.chaeo.annotators import draw_boxes_on_3d_image +from extensions.chaeo.models import PatchStackObjectClassifier from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta from extensions.chaeo.zmask import build_zmask_from_object_mask, project_stack_from_focal_points from extensions.ilastik.models import IlastikPixelClassifierModel @@ -224,7 +225,11 @@ def get_object_map_from_zstack( zmask_expand_box_by: int = None, **kwargs, ) -> Dict: + assert len(models) == 2 pixel_classifier = models[0] + assert isinstance(pixel_classifier, IlastikPixelClassifierModel) + object_classifier = models[1] + assert isinstance(object_classifier, PatchStackObjectClassifier) ti, stack, fstem, obmask, pxmap, zmask, zmask_meta, df, interm = get_zmask_meta( input_file_path, @@ -256,14 +261,25 @@ def get_object_map_from_zstack( ) # send patches and mask stacks to object classifier - mod = PatchStackObjectClassifier({'project_file': where_patch_stack / ilp}) - result_acc, _ = mod.infer(raw, mask) - # write_accessor_data_to_file(where_patch_stack / f'zstack_train_result_{suffix}.tif', result_acc) + result_acc, _ = MonoPatchStack( + object_classifier.infer(patches_acc, patch_masks_acc) + ) + + object_labels_map = np.copy(interm['label_map']) + assert object_labels_map.shape == interm['label_map'].shape + assert object_labels_map.dtype == interm['label_map'].dtype # assign labels to object map: - for ii, mi in enumerate(zmask_meta): - obj = mi['info'] - la = obj.label + for ii in range(0, len(zmask_meta)): + mi = zmask_meta[ii] + object_label_id = mi['info'].label + result_label_map = result_acc.iat(ii) + unique_values = np.unique(result_label_map) + assert len(unique_values) == 2 + assert unique_values[0] == 0 + inferred_class = result_acc.iat(ii) + ii_mask = object_labels_map == object_label_id + object_labels_map[ii_mask] = unique_values[1] patch = patches_acc.iat(ii)