From 4d11a9cecefe86ac52b2f9a907d9f09f0e8e4b0f Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 7 Feb 2024 13:19:40 +0100
Subject: [PATCH] RoiSet now iterable, i.e. traverse Rois over Pandas API

---
 model_server/extensions/chaeo/products.py          | 4 ++--
 model_server/extensions/chaeo/tests/test_zstack.py | 3 +--
 model_server/extensions/chaeo/zmask.py             | 3 +++
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index bd643df7..b1cedda6 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -71,7 +71,7 @@ def write_patch_to_file(where, fname, yxcz):
 
 def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
     patches = []
-    for roi in roiset.get_df().itertuples():
+    for roi in roiset:
         patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
         patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255
 
@@ -86,7 +86,7 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', **
     patches_acc = get_patch_masks(roiset, pad_to=pad_to)
 
     exported = []
-    for i, roi in enumerate(roiset.get_df().itertuples()):  # assumes index of patches_acc is same as dataframe
+    for i, roi in enumerate(roiset):  # assumes index of patches_acc is same as dataframe
         patch = patches_acc.iat_yxcz(i)
         ext = 'png'
         fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index c07deb80..b5a4ffa5 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -111,8 +111,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_zmask_rel_slices_are_valid(self):
         roiset = self._make_roi_set()
-        # for i, s in enumerate(roiset.get_slices()):
-        for roi in roiset.get_df().itertuples():
+        for roi in roiset:
             ebb = roiset.acc_raw.data[roi.slice]
             self.assertEqual(len(ebb.shape), 4)
             self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 31d5166b..59567bf8 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -80,6 +80,9 @@ class RoiSet(object):
         self.object_id_labels = self.interm['label_map']
         self.object_class_map = {}  # classification results
 
+    def __iter__(self):
+        """Expose ROI meta information via the Pandas.DataFrame API"""
+        return self._df.itertuples(name='Roi')
 
     @staticmethod
     def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
-- 
GitLab