From 2deca4a19875fe1a81ae96658004be1173710d74 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 4 Apr 2024 10:00:17 +0200
Subject: [PATCH] Harmonize relative paths with exported dataframe, too

---
 model_server/base/roiset.py | 6 ++++--
 tests/test_roiset.py        | 9 +++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 8d2fe3eb..ce72a89b 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -330,6 +330,7 @@ class RoiSet(object):
             om[self.acc_obj_ids.data == roi.label] = oc
         self.object_class_maps[name] = InMemoryDataAccessor(om)
 
+
     def export_dataframe(self, csv_path: Path) -> str:
         csv_path.parent.mkdir(parents=True, exist_ok=True)
         self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False)
@@ -587,10 +588,11 @@ class RoiSet(object):
             pad_to=None,
             expanded=False
         )
-        self._df = self._df.join(df_exp.patch_mask_path)
+        se_pa = df_exp.patch_mask_path.apply(lambda x: str(Path('tight_patch_masks') / x)).rename('tight_patch_masks')
+        self._df = self._df.join(se_pa)
         df_fn = self.export_dataframe(where / 'dataframe' / (prefix + '.csv'))
         record['dataframe'] = str(Path('dataframe') / df_fn)
-        record['tight_patch_masks'] = [str(Path('tight_patch_masks') / fn) for fn in df_exp.patch_mask_path]
+        record['tight_patch_masks'] = list(se_pa)
         return record
 
     @staticmethod
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index 6ba5c732..da9b43db 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -383,6 +383,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             params=p
         )
 
+        # test on return paths
         for k, v in res.items():
             if isinstance(v, list):
                 for f in v:
@@ -392,6 +393,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
                 self.assertFalse(Path(v).is_absolute())
                 self.assertTrue((where / v).exists())
 
+        # test on paths in CSV
+        test_df = pd.read_csv(where / res['dataframe'])
+        for c in test_df.columns:
+            if '_path' in c:
+                for f in test_df[c]:
+                    self.assertTrue((where / f).exists(), where / f)
+
+
 
 class TestRoiSetFromZmask(unittest.TestCase):
 
-- 
GitLab