From 01adb4b2372fda07c0c75ff89a8cc0dc5e885b20 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 21 Dec 2023 12:03:09 +0100
Subject: [PATCH] Pass params model straight to product exporters, so less code
 dedicated to matching parameters

---
 extensions/chaeo/annotators.py |  2 +-
 extensions/chaeo/products.py   | 14 +++++++-------
 extensions/chaeo/workflows.py  |  1 -
 extensions/chaeo/zmask.py      | 29 ++++++++++-------------------
 4 files changed, 18 insertions(+), 28 deletions(-)

diff --git a/extensions/chaeo/annotators.py b/extensions/chaeo/annotators.py
index eb4996c3..687b4ca1 100644
--- a/extensions/chaeo/annotators.py
+++ b/extensions/chaeo/annotators.py
@@ -23,7 +23,7 @@ def draw_boxes_on_2d_image(yx_img, boxes, **kwargs):
 
         draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth)
 
-        if kwargs.get('add_label') is True:
+        if kwargs.get('draw_label') is True:
             draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb')
 
     return pilimg
diff --git a/extensions/chaeo/products.py b/extensions/chaeo/products.py
index 0f1ac976..155a544b 100644
--- a/extensions/chaeo/products.py
+++ b/extensions/chaeo/products.py
@@ -331,8 +331,8 @@ def export_multichannel_patches_from_zstack(
     where: Path,
     stack: GenericImageDataAccessor,
     zmask_meta: list,
-    ch_rgb_overlay: tuple = None,
-    overlay_gain: tuple = (1.0, 1.0, 1.0),
+    rgb_overlay_channels: list = None,
+    rgb_overlay_weights: list = [1.0, 1.0, 1.0],
     ch_white: int = None,
     **kwargs
 ):
@@ -360,17 +360,17 @@ def export_multichannel_patches_from_zstack(
     else:
         mdata = idata
 
-    if ch_rgb_overlay:
-        assert len(ch_rgb_overlay) == 3
-        assert len(overlay_gain) == 3
-        for ii, ci in enumerate(ch_rgb_overlay):
+    if rgb_overlay_channels:
+        assert len(rgb_overlay_channels) == 3
+        assert len(rgb_overlay_weights) == 3
+        for ii, ci in enumerate(rgb_overlay_channels):
             if ci is None:
                 continue
             assert isinstance(ci, int)
             assert ci < stack.chroma
             mdata[:, :, ii, :] = _safe_add(
                 mdata[:, :, ii, :],
-                overlay_gain[ii],
+                rgb_overlay_weights[ii],
                 idata[:, :, ci, :]
             )
 
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 8dd7554f..811aed5f 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -122,7 +122,6 @@ def infer_object_map_from_zstack(
             Path(output_folder_path) / 'patch_masks',
             fstem,
             patches_channel,
-            exports.patch_masks
         )
 
     if exports.annotated_z_stack:
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 76e96bf8..77854e3e 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression
 
 from extensions.chaeo.annotators import draw_boxes_on_3d_image
 from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
-from extensions.chaeo.params import RoiFilter, RoiSetExportParams
+from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams
 from extensions.chaeo.process import mask_largest_object
 from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.models import InstanceSegmentationModel
@@ -40,7 +40,7 @@ class RoiSet(object):
     def get_argmax(self):
         return self.interm.argmax
 
-    def export_3d_patches(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_3d_patches(self, where, prefix, channel, params: PatchParams):
         if not self.count:
             return
         files = export_patches_from_zstack(
@@ -48,12 +48,11 @@ class RoiSet(object):
             self.acc_raw.get_one_channel_data(channel),
             self.zmask_meta,
             prefix=prefix,
-            draw_bounding_box=params.draw_bounding_box,
-            rescale_clip=params.rescale_clip,
             make_3d=True,
+            **params.__dict__,
         )
 
-    def export_2d_patches_for_training(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_2d_patches_for_training(self, where, prefix, channel, params: PatchParams):
         if not self.count:
             return
         files = export_multichannel_patches_from_zstack(
@@ -62,15 +61,14 @@ class RoiSet(object):
             self.zmask_meta,
             ch_white=channel,
             prefix=prefix,
-            rescale_clip=params.rescale_clip,
             make_3d=False,
-            focus_metric=params.focus_metric,
+            **params.__dict__,
         )
         df_patches = pd.DataFrame(files)
         self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
         self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
 
-    def export_2d_patches_for_annotation(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_2d_patches_for_annotation(self, where, prefix, channel, params: PatchParams):
         if not self.count:
             return
         files = export_multichannel_patches_from_zstack(
@@ -78,21 +76,14 @@ class RoiSet(object):
             self.acc_raw,
             self.zmask_meta,
             prefix=prefix,
-            draw_bounding_box=params.draw_bounding_box,
-            rescale_clip=params.rescale_clip,
             make_3d=False,
-            focus_metric=params.focus_metric,
             ch_white=channel,
-            ch_rgb_overlay=params.rgb_overlay_channels,
             bounding_box_channel=1,
             bounding_box_linewidth=2,
-            draw_contour=params.draw_contour,
-            draw_mask=params.draw_mask,
-            overlay_gain=params.rgb_overlay_weights,
-            pad_to=params.pad_to,
+            **params.__dict__,
         )
 
-    def export_patch_masks(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_patch_masks(self, where, prefix, channel):
         if not self.count:
             return
         files = export_patch_masks_from_zstack(
@@ -102,12 +93,12 @@ class RoiSet(object):
             prefix=prefix,
         )
 
-    def export_annotated_zstack(self, where, prefix, channel, params: RoiSetExportParams):
+    def export_annotated_zstack(self, where, prefix, channel, params: AnnotatedZStackParams):
         annotated = InMemoryDataAccessor(
             draw_boxes_on_3d_image(
                 self.acc_raw.get_one_channel_data(channel).data,
                 self.zmask_meta,
-                add_label=params.draw_label,
+                **params.__dict__,
             )
         )
         write_accessor_data_to_file(
-- 
GitLab