From 7ce760c7a246ef47ae2b871b3a06d5bf3acad603 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 22:08:14 +0100
Subject: [PATCH] simplifying export calls

---
 model_server/extensions/chaeo/params.py   |  2 +-
 model_server/extensions/chaeo/products.py | 14 ++++++--------
 model_server/extensions/chaeo/zmask.py    |  5 +----
 3 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/model_server/extensions/chaeo/params.py b/model_server/extensions/chaeo/params.py
index 0880ca12..71b9497d 100644
--- a/model_server/extensions/chaeo/params.py
+++ b/model_server/extensions/chaeo/params.py
@@ -36,7 +36,7 @@ class RoiSetExportParams(BaseModel):
     patches_3d: Union[PatchParams, None] = None
     patches_2d_for_annotation: Union[PatchParams, None] = None
     patches_2d_for_training: Union[PatchParams, None] = None
-    patch_masks: bool = False
+    patch_masks: Union[PatchParams, None] = None
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
     object_classes: bool = False
     dataframe: bool = False
diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index 84535d73..3664ce8e 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -60,9 +60,9 @@ def write_patch_to_file(where, fname, yxcz):
 
 def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
     patches = []
-    for ob in roiset.get_df().itertuples('Roi'):
-        patch = np.zeros((ob.ebb_h, ob.ebb_w, 1, 1), dtype='uint8')
-        patch[ob.relative_slice][:, :, 0, 0] = ob.mask * 255
+    for roi in roiset.get_df().itertuples():
+        patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
+        patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
 
         if pad_to:
             patch = pad(patch, pad_to)
@@ -72,15 +72,13 @@ def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
 
 
 def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') -> list:
-    patches_acc = roiset.get_patch_masks(pad_to=pad_to)
+    patches_acc = get_patch_masks(roiset, pad_to=pad_to)
 
     exported = []
-    for i in range(0, roiset.count):
-        mi = roiset.zmask_meta[i]
-        obj = mi['info']
+    for i, roi in enumerate(roiset.get_df().itertuples()):  # assumes index of patches_acc is same as dataframe
         patch = patches_acc.iat_yxcz(i)
         ext = 'png'
-        fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}'
+        fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
         write_patch_to_file(where, fname, patch)
         exported.append(fname)
     return exported
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index b06c276f..781b6cbc 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -181,9 +181,6 @@ class RoiSet(object):
     def get_object_mask_by_class(self, class_id):
         return self.object_id_labels == class_id
 
-    def get_patch_masks(self, **kwargs) -> MonoPatchStack:
-        return get_patch_masks(self, **kwargs)
-
     def export_patch_masks(self, where, **kwargs) -> list:
         return export_patch_masks(self, where, **kwargs)
 
@@ -284,7 +281,7 @@ class RoiSet(object):
                 self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
                 self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
             if k == 'patch_masks':
-                self.export_patch_masks(subdir, prefix=pr)
+                self.export_patch_masks(subdir, prefix=pr, **params.patch_masks)
             if k == 'annotated_zstacks':
                 annotated = InMemoryDataAccessor(
                     draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)
-- 
GitLab