From 16d96d55065c7c451e3aad115441df4304f1e440 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Dec 2023 14:33:29 +0100 Subject: [PATCH] Implemented but did not yet test adding object classification result to RoiSet dataframe --- extensions/chaeo/params.py | 7 --- extensions/chaeo/workflows.py | 109 +++++++++++++++++----------------- extensions/chaeo/zmask.py | 52 ++++++++++++---- 3 files changed, 94 insertions(+), 74 deletions(-) diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py index e516f49b..796469e9 100644 --- a/extensions/chaeo/params.py +++ b/extensions/chaeo/params.py @@ -26,11 +26,6 @@ class RoiSetExportParams(BaseModel): annotated_z_stack: Union[AnnotatedZStackParams, None] = None -class RoiClassifierValue(BaseModel): - name: str - value: int - - class RoiFilterRange(BaseModel): min: float max: float @@ -39,5 +34,3 @@ class RoiFilterRange(BaseModel): class RoiFilter(BaseModel): area: Union[RoiFilterRange, None] = None solidity: Union[RoiFilterRange, None] = None - classifiers: List[RoiClassifierValue] = [] - diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 6e9732f6..50428f09 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -12,7 +12,7 @@ from sklearn.model_selection import train_test_split from extensions.chaeo.accessors import MonoPatchStack from extensions.chaeo.annotators import draw_boxes_on_3d_image from extensions.chaeo.models import PatchStackObjectClassifier -from extensions.chaeo.params import ZMaskExportParams +from extensions.chaeo.params import RoiSetExportParams from extensions.chaeo.process import mask_largest_object from extensions.chaeo.products import export_patches_from_zstack, export_patch_masks_from_zstack, export_multichannel_patches_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet @@ -240,7 +240,7 @@ def infer_object_map_from_zstack( zmask_type: str = 'boxes', zmask_filters: Dict = None, # zmask_expand_box_by: int = None, - exports: ZMaskExportParams = None, + exports: RoiSetExportParams = None, **kwargs, ) -> Dict: assert len(models) == 2 @@ -281,7 +281,7 @@ def infer_object_map_from_zstack( ti.click('threshold_pixel_mask') # make zmask - obj_table = RoiSet( + rois = RoiSet( obmask.get_one_channel_data(pxmap_foreground_channel), stack.get_one_channel_data(segmentation_channel), mask_type=zmask_type, @@ -291,7 +291,7 @@ def infer_object_map_from_zstack( ti.click('generate_zmasks') # record pixel scale - obj_table.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X')) + rois.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X')) # ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta( # input_file_path, @@ -307,79 +307,79 @@ def infer_object_map_from_zstack( # **kwargs # ) - # extract patches to accessor - patches_acc = get_patches_from_zmask_meta( - stack.get_one_channel_data(patches_channel), - obj_table.zmask_meta, - rescale_clip=zmask_clip, - make_3d=False, - focus_metric='max_sobel', - **kwargs - ) - - # TODO: make this a method of ZMaskObjectTable class - # extract masks - patch_masks_acc = get_patch_masks_from_zmask_meta( - stack, - obj_table.zmask_meta, - **kwargs - ) - - # TODO: add ZMaskObjectTable method to apply object classification results as new DataFrame column - # send patches and mask stacks to object classifier - result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc) - - labels_map = obj_table.interm['label_map'] - output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype) - assert labels_map.shape == obj_table.get_label_map().shape - assert labels_map.dtype == obj_table.get_label_map().dtype + # # extract patches to accessor + # patches_acc = get_patches_from_zmask_meta( + # stack.get_one_channel_data(patches_channel), + # obj_table.zmask_meta, + # rescale_clip=zmask_clip, + # make_3d=False, + # focus_metric='max_sobel', + # **kwargs + # ) + # + # # extract masks + # patch_masks_acc = get_patch_masks_from_zmask_meta( + # stack, + # obj_table.zmask_meta, + # **kwargs + # ) - # assign labels to object map: - meta = [] - for ii in range(0, len(obj_table.zmask_meta)): - object_id = obj_table.zmask_meta[ii]['info'].label - result_patch = mask_largest_object(result_acc.iat(ii)) - object_class = np.unique(result_patch)[1] - output_map[labels_map == object_id] = object_class - meta.append({'object_id': ii, 'object_class': object_id}) + # # send patches and mask stacks to object classifier + # result_acc, _ = object_classifier.infer(patches_acc, patch_masks_acc) + + # labels_map = obj_table.interm['label_map'] + # output_map = np.zeros(labels_map.shape, dtype=labels_map.dtype) + # assert labels_map.shape == obj_table.get_label_map().shape + # assert labels_map.dtype == obj_table.get_label_map().dtype + # + # # assign labels to object map: + # meta = [] + # for ii in range(0, len(obj_table.zmask_meta)): + # object_id = obj_table.zmask_meta[ii]['info'].label + # result_patch = mask_largest_object(result_acc.iat(ii)) + # object_class = np.unique(result_patch)[1] + # output_map[labels_map == object_id] = object_class + # meta.append({'object_id': ii, 'object_class': object_id}) + + object_class_map = rois.classify_by(patches_channel) # TODO: add ZMaskObjectTable method to export object map output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) write_accessor_data_to_file( output_path, - InMemoryDataAccessor(output_map) + InMemoryDataAccessor(object_class_map) ) ti.click('export_object_classes') if exports.patches_3d: - obj_table.export_3d_patches( - Path(output_folder_path) / '3d_patches', - fstem, - patches_channel, - exports.patches_3d - ) - ti.click('export_3d_patches') + rois.export_3d_patches( + Path(output_folder_path) / '3d_patches', + fstem, + patches_channel, + exports.patches_3d + ) + ti.click('export_3d_patches') if exports.patches_2d_for_annotation: - obj_table.export_2d_patches_for_annotation( + rois.export_2d_patches_for_annotation( Path(output_folder_path) / '2d_patches_annotation', fstem, patches_channel, exports.patches_2d_for_annotation ) - ti.click('export_2d_patches_for_annotation') + ti.click('export_2d_patches_for_annotation') if exports.patches_2d_for_training: - obj_table.export_2d_patches_for_training( + rois.export_2d_patches_for_training( Path(output_folder_path) / '2d_patches_training', fstem, patches_channel, exports.patches_2d_for_training ) - ti.click('export_2d_patches_for_training') + ti.click('export_2d_patches_for_training') if exports.patch_masks: - obj_table.export_patch_masks( + rois.export_patch_masks( Path(output_folder_path) / 'patch_masks', fstem, patches_channel, @@ -387,18 +387,17 @@ def infer_object_map_from_zstack( ) if exports.annotated_z_stack: - obj_table.export_annotated_zstack( + rois.export_annotated_zstack( Path(output_folder_path) / 'patch_masks', fstem, patches_channel, exports.annotated_z_stack ) - ti.click('export_annotated_zstack') - + ti.click('export_annotated_zstack') return { 'timer_results': ti.events, - 'dataframe': pd.DataFrame(meta), + 'dataframe': rois.df, 'interm': {}, 'output_path': output_path.__str__(), } diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 67a78334..69e128ef 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -8,10 +8,11 @@ from sklearn.preprocessing import PolynomialFeatures from sklearn.linear_model import LinearRegression from extensions.chaeo.annotators import draw_boxes_on_3d_image -from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack +from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta from extensions.chaeo.params import RoiFilter, RoiSetExportParams +from extensions.chaeo.process import mask_largest_object from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file - +from model_server.models import InstanceSegmentationModel class RoiSet(object): @@ -30,13 +31,11 @@ class RoiSet(object): filters=filters, mask_type=mask_type, expand_box_by=expand_box_by - ) # currently, some methods can add columns to self.df + ) self.acc_raw = acc_raw self.count = len(self.zmask_meta) - - def get_label_map(self): - return self.interm.lamap + self.object_id_labels = self.interm['label_map'] def get_argmax(self): return self.interm.argmax @@ -152,19 +151,48 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def classify_by(self, column_name, object_classification_model, patch_basis=True): - # add one column to df where each value is an integer - pass + def get_raw_patches(self, channel): + return get_patches_from_zmask_meta(self.acc_raw(channel), self.zmask_meta) - def get_raw_patches(self, filters: RoiFilter): - pass + def get_patch_masks(self): + return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta) + + def classify_by(self, channel, object_classification_model: InstanceSegmentationModel): + # do this on a patch basis, i.e. only one object per frame + obmap_patches = object_classification_model.label_instance_class( + self.get_raw_patches(channel), + self.get_patch_masks() + ) - def get_patch_masks(self, filters: RoiFilter): + lamap = self.object_id_labels + output_map = np.zeros(lamap.shape, dtype=lamap.dtype) + self.df['instance_class'] = np.nan + + # assign labels to object map: + for ii in range(0, self.count): + object_id = self.zmask_meta[ii]['info'].label + result_patch = mask_largest_object(obmap_patches.iat(ii)) + object_class = np.unique(result_patch)[1] + output_map[self.object_id_labels == object_id] = object_class + self.df[object_id, 'instance_class'] = object_class + + return InMemoryDataAccessor(output_map) + + # TODO: test + def get_object_mask_by_id(self, obj_id): + return self.object_id_labels == obj_id + + def get_object_mask_by_class(self, class_id): + return self.object_id_labels == class_id + + # TODO: implement + def get_object_patch_by_id(self, obj_id): pass def get_object_map(self, filters: RoiFilter): pass + def build_zmask_from_object_mask( obmask: GenericImageDataAccessor, zstack: GenericImageDataAccessor, -- GitLab