From 4d11a9cecefe86ac52b2f9a907d9f09f0e8e4b0f Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 7 Feb 2024 13:19:40 +0100 Subject: [PATCH] RoiSet now iterable, i.e. traverse Rois over Pandas API --- model_server/extensions/chaeo/products.py | 4 ++-- model_server/extensions/chaeo/tests/test_zstack.py | 3 +-- model_server/extensions/chaeo/zmask.py | 3 +++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index bd643df7..b1cedda6 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -71,7 +71,7 @@ def write_patch_to_file(where, fname, yxcz): def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: patches = [] - for roi in roiset.get_df().itertuples(): + for roi in roiset: patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 @@ -86,7 +86,7 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', ** patches_acc = get_patch_masks(roiset, pad_to=pad_to) exported = [] - for i, roi in enumerate(roiset.get_df().itertuples()): # assumes index of patches_acc is same as dataframe + for i, roi in enumerate(roiset): # assumes index of patches_acc is same as dataframe patch = patches_acc.iat_yxcz(i) ext = 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index c07deb80..b5a4ffa5 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -111,8 +111,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_zmask_rel_slices_are_valid(self): roiset = self._make_roi_set() - # for i, s in enumerate(roiset.get_slices()): - for roi in roiset.get_df().itertuples(): + for roi in roiset: ebb = roiset.acc_raw.data[roi.slice] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 31d5166b..59567bf8 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -80,6 +80,9 @@ class RoiSet(object): self.object_id_labels = self.interm['label_map'] self.object_class_map = {} # classification results + def __iter__(self): + """Expose ROI meta information via the Pandas.DataFrame API""" + return self._df.itertuples(name='Roi') @staticmethod def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: -- GitLab