diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index bd643df7775d91c556885b51f43f49647ec73b10..b1cedda60dacff5d33b782617900344a22d09f40 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -71,7 +71,7 @@ def write_patch_to_file(where, fname, yxcz): def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack: patches = [] - for roi in roiset.get_df().itertuples(): + for roi in roiset: patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') patch[roi.relative_slice][:, :, 0, 0] = roi.mask * 255 @@ -86,7 +86,7 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', ** patches_acc = get_patch_masks(roiset, pad_to=pad_to) exported = [] - for i, roi in enumerate(roiset.get_df().itertuples()): # assumes index of patches_acc is same as dataframe + for i, roi in enumerate(roiset): # assumes index of patches_acc is same as dataframe patch = patches_acc.iat_yxcz(i) ext = 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index c07deb807d44afc0c13a0b3fa811eaf0d36725bd..b5a4ffa5be8e8c573200fce580b809b26f79ed83 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -111,8 +111,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_zmask_rel_slices_are_valid(self): roiset = self._make_roi_set() - # for i, s in enumerate(roiset.get_slices()): - for roi in roiset.get_df().itertuples(): + for roi in roiset: ebb = roiset.acc_raw.data[roi.slice] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 31d5166b1ecc4ab0a6936ba1975f0b350306e7f8..59567bf8eeb0287aa41a7011dae23d9509f64507 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -80,6 +80,9 @@ class RoiSet(object): self.object_id_labels = self.interm['label_map'] self.object_class_map = {} # classification results + def __iter__(self): + """Expose ROI meta information via the Pandas.DataFrame API""" + return self._df.itertuples(name='Roi') @staticmethod def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: