From 3e7e34c7eb682ea9167c54da1da3b838ef9430aa Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 5 Feb 2024 11:03:47 +0100
Subject: [PATCH] get_patches now returns an extended DataFrame with a patch
 column, so that exporters have access to patch metadata with no risk of
 mixing up indexing

---
 model_server/extensions/chaeo/products.py | 45 ++++++++++++-----------
 model_server/extensions/chaeo/zmask.py    | 10 ++---
 2 files changed, 27 insertions(+), 28 deletions(-)

diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index 679271ee..57ff9cbe 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -85,31 +85,20 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') ->
 
 
 def get_patches_from_zmask_meta(
-        # stack: GenericImageDataAccessor,
-        # zmask_meta: list,
         roiset,
         rescale_clip: float = 0.0,
         pad_to: int = 256,
         make_3d: bool = False,
         focus_metric: str = None,
         **kwargs
-) -> MonoPatchStack:
-    patches = []
+) -> pd.DataFrame:
+    # patches = []
 
     # for mi in zmask_meta:
-    for i, roi in enumerate(roiset.get_df().itertuples()): # TODO: call RoiSet.iter() when implemented
-
-        # sl = roi['slice']
-        # rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF
-        # idx = mi['df_index']
-        #
-        # x0 = rbb['x0']
-        # y0 = rbb['y0']
-        # x1 = rbb['x1']
-        # y1 = rbb['y1']
-        #
-        # sp_sl = np.s_[y0: y1, x0: x1, :, :]
+    # dfe = roiset.get_df().assign(patch=object())
+    # for i, roi in enumerate(dfe.itertuples()):
 
+    def _make_patch(roi):
         patch3d = roiset.acc_raw.data[roi.slice]
         ph, pw, pc, pz = patch3d.shape
         subpatch = patch3d[roi.relative_slice]
@@ -173,12 +162,17 @@ def get_patches_from_zmask_meta(
         if pad_to:
             patch = pad(patch, pad_to)
 
-        patches.append(patch)
+        # patches.append(patch)
+        return patch
+        # dfe.loc[i, 'patch'] = patch
+    dfe = roiset.get_df()
+    dfe['patch'] = roiset.get_df().apply(lambda r: _make_patch(r), axis=1)
+    return dfe
 
-    if not make_3d and pc == 1:
-        return MonoPatchStack(patches)
-    else:
-        return Multichannel3dPatchStack(patches)
+    # if not make_3d and pc == 1:
+    #     return MonoPatchStack(patches)
+    # else:
+    #     return Multichannel3dPatchStack(patches)
 
 def export_patches_from_zstack(
         where: Path,
@@ -192,7 +186,7 @@ def export_patches_from_zstack(
         focus_metric: str = None,
         **kwargs
 ):
-    patches_acc = get_patches_from_zmask_meta(
+    patches_df = get_patches_from_zmask_meta(
         roiset,
         # stack,
         # zmask_meta,
@@ -203,6 +197,13 @@ def export_patches_from_zstack(
         **kwargs
     )
 
+    pc = roiset.acc_raw.chroma
+    patches = list(patches_df['patch'])
+    if not make_3d and pc == 1:
+        patches_acc = MonoPatchStack(patches)
+    else:
+        patches_acc = Multichannel3dPatchStack(patches)
+
     exported = []
     for i, roi in enumerate(roiset.get_df().itertuples()):  # just used for label info
     # for i in range(0, len(zmask_meta)):
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 117e189a..83cf744a 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -80,6 +80,7 @@ class RoiSet(object):
         self.object_id_labels = self.interm['label_map']
         self.object_class_map = None
 
+
     @staticmethod
     def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
         """
@@ -185,11 +186,8 @@ class RoiSet(object):
     def export_patch_masks(self, where, **kwargs) -> list:
         return export_patch_masks(self, where, **kwargs)
 
-    def get_raw_patches(self, channel):
-        return get_patches_from_zmask_meta(
-            self.acc_raw.get_one_channel_data(channel),
-            self.zmask_meta
-        )
+    def get_raw_patches(self, channel):  # tight, un-annotated 2d patches
+        return get_patches_from_zmask_meta(self, pad_to=None)
 
     def get_zmask(self, mask_type='boxes'):
         """
@@ -285,7 +283,7 @@ class RoiSet(object):
                 self.export_patch_masks(subdir, prefix=pr, **params.patch_masks)
             if k == 'annotated_zstacks':
                 annotated = InMemoryDataAccessor(
-                    draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)
+                    draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) # TODO remove zmask_meta ref
                 )
                 write_accessor_data_to_file(subdir / (pr + '.tif'), annotated)
             if k == 'object_classes':
-- 
GitLab