From 01adb4b2372fda07c0c75ff89a8cc0dc5e885b20 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 21 Dec 2023 12:03:09 +0100 Subject: [PATCH] Pass params model straight to product exporters, so less code dedicated to matching parameters --- extensions/chaeo/annotators.py | 2 +- extensions/chaeo/products.py | 14 +++++++------- extensions/chaeo/workflows.py | 1 - extensions/chaeo/zmask.py | 29 ++++++++++------------------- 4 files changed, 18 insertions(+), 28 deletions(-) diff --git a/extensions/chaeo/annotators.py b/extensions/chaeo/annotators.py index eb4996c3..687b4ca1 100644 --- a/extensions/chaeo/annotators.py +++ b/extensions/chaeo/annotators.py @@ -23,7 +23,7 @@ def draw_boxes_on_2d_image(yx_img, boxes, **kwargs): draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) - if kwargs.get('add_label') is True: + if kwargs.get('draw_label') is True: draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb') return pilimg diff --git a/extensions/chaeo/products.py b/extensions/chaeo/products.py index 0f1ac976..155a544b 100644 --- a/extensions/chaeo/products.py +++ b/extensions/chaeo/products.py @@ -331,8 +331,8 @@ def export_multichannel_patches_from_zstack( where: Path, stack: GenericImageDataAccessor, zmask_meta: list, - ch_rgb_overlay: tuple = None, - overlay_gain: tuple = (1.0, 1.0, 1.0), + rgb_overlay_channels: list = None, + rgb_overlay_weights: list = [1.0, 1.0, 1.0], ch_white: int = None, **kwargs ): @@ -360,17 +360,17 @@ def export_multichannel_patches_from_zstack( else: mdata = idata - if ch_rgb_overlay: - assert len(ch_rgb_overlay) == 3 - assert len(overlay_gain) == 3 - for ii, ci in enumerate(ch_rgb_overlay): + if rgb_overlay_channels: + assert len(rgb_overlay_channels) == 3 + assert len(rgb_overlay_weights) == 3 + for ii, ci in enumerate(rgb_overlay_channels): if ci is None: continue assert isinstance(ci, int) assert ci < stack.chroma mdata[:, :, ii, :] = _safe_add( mdata[:, :, ii, :], - overlay_gain[ii], + rgb_overlay_weights[ii], idata[:, :, ci, :] ) diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 8dd7554f..811aed5f 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -122,7 +122,6 @@ def infer_object_map_from_zstack( Path(output_folder_path) / 'patch_masks', fstem, patches_channel, - exports.patch_masks ) if exports.annotated_z_stack: diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 76e96bf8..77854e3e 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression from extensions.chaeo.annotators import draw_boxes_on_3d_image from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta -from extensions.chaeo.params import RoiFilter, RoiSetExportParams +from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams from extensions.chaeo.process import mask_largest_object from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.models import InstanceSegmentationModel @@ -40,7 +40,7 @@ class RoiSet(object): def get_argmax(self): return self.interm.argmax - def export_3d_patches(self, where, prefix, channel, params: RoiSetExportParams): + def export_3d_patches(self, where, prefix, channel, params: PatchParams): if not self.count: return files = export_patches_from_zstack( @@ -48,12 +48,11 @@ class RoiSet(object): self.acc_raw.get_one_channel_data(channel), self.zmask_meta, prefix=prefix, - draw_bounding_box=params.draw_bounding_box, - rescale_clip=params.rescale_clip, make_3d=True, + **params.__dict__, ) - def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams): + def export_2d_patches_for_training(self, where, prefix, channel, params: PatchParams): if not self.count: return files = export_multichannel_patches_from_zstack( @@ -62,15 +61,14 @@ class RoiSet(object): self.zmask_meta, ch_white=channel, prefix=prefix, - rescale_clip=params.rescale_clip, make_3d=False, - focus_metric=params.focus_metric, + **params.__dict__, ) df_patches = pd.DataFrame(files) self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) - def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams): + def export_2d_patches_for_annotation(self, where, prefix, channel, params: PatchParams): if not self.count: return files = export_multichannel_patches_from_zstack( @@ -78,21 +76,14 @@ class RoiSet(object): self.acc_raw, self.zmask_meta, prefix=prefix, - draw_bounding_box=params.draw_bounding_box, - rescale_clip=params.rescale_clip, make_3d=False, - focus_metric=params.focus_metric, ch_white=channel, - ch_rgb_overlay=params.rgb_overlay_channels, bounding_box_channel=1, bounding_box_linewidth=2, - draw_contour=params.draw_contour, - draw_mask=params.draw_mask, - overlay_gain=params.rgb_overlay_weights, - pad_to=params.pad_to, + **params.__dict__, ) - def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams): + def export_patch_masks(self, where, prefix, channel): if not self.count: return files = export_patch_masks_from_zstack( @@ -102,12 +93,12 @@ class RoiSet(object): prefix=prefix, ) - def export_annotated_zstack(self, where, prefix, channel, params: RoiSetExportParams): + def export_annotated_zstack(self, where, prefix, channel, params: AnnotatedZStackParams): annotated = InMemoryDataAccessor( draw_boxes_on_3d_image( self.acc_raw.get_one_channel_data(channel).data, self.zmask_meta, - add_label=params.draw_label, + **params.__dict__, ) ) write_accessor_data_to_file( -- GitLab