From 2deca4a19875fe1a81ae96658004be1173710d74 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 4 Apr 2024 10:00:17 +0200 Subject: [PATCH] Harmonize relative paths with exported dataframe, too --- model_server/base/roiset.py | 6 ++++-- tests/test_roiset.py | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 8d2fe3eb..ce72a89b 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -330,6 +330,7 @@ class RoiSet(object): om[self.acc_obj_ids.data == roi.label] = oc self.object_class_maps[name] = InMemoryDataAccessor(om) + def export_dataframe(self, csv_path: Path) -> str: csv_path.parent.mkdir(parents=True, exist_ok=True) self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False) @@ -587,10 +588,11 @@ class RoiSet(object): pad_to=None, expanded=False ) - self._df = self._df.join(df_exp.patch_mask_path) + se_pa = df_exp.patch_mask_path.apply(lambda x: str(Path('tight_patch_masks') / x)).rename('tight_patch_masks') + self._df = self._df.join(se_pa) df_fn = self.export_dataframe(where / 'dataframe' / (prefix + '.csv')) record['dataframe'] = str(Path('dataframe') / df_fn) - record['tight_patch_masks'] = [str(Path('tight_patch_masks') / fn) for fn in df_exp.patch_mask_path] + record['tight_patch_masks'] = list(se_pa) return record @staticmethod diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 6ba5c732..da9b43db 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -383,6 +383,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa params=p ) + # test on return paths for k, v in res.items(): if isinstance(v, list): for f in v: @@ -392,6 +393,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertFalse(Path(v).is_absolute()) self.assertTrue((where / v).exists()) + # test on paths in CSV + test_df = pd.read_csv(where / res['dataframe']) + for c in test_df.columns: + if '_path' in c: + for f in test_df[c]: + self.assertTrue((where / f).exists(), where / f) + + class TestRoiSetFromZmask(unittest.TestCase): -- GitLab