From 017e52cebd3d95de9846e869f574efeaf27d79a3 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 4 Apr 2024 10:27:49 +0200
Subject: [PATCH] Serialize after running other exports, as patch path is
 embedded in dataframe; write this with subdirectory

---
 model_server/base/roiset.py | 8 ++++----
 tests/test_roiset.py        | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 697f22a3..59ed23b3 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -537,9 +537,6 @@ class RoiSet(object):
         if not self.count:
             return
 
-        # export dataframe and patch masks
-        record = self.serialize(where, prefix=prefix)
-
         for k in params.dict().keys():
             subdir = where / k
             pr = prefix
@@ -561,7 +558,7 @@ class RoiSet(object):
                 df_exp = self.export_patches(
                     subdir, white_channel=channel, prefix=pr, make_3d=False, **kp
                 )
-                self._df = self._df.join(df_exp.patch_path)
+                self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(Path('patches_2d') / x)))
                 self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
                 record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
             if k == 'annotated_zstacks':
@@ -572,6 +569,9 @@ class RoiSet(object):
                     write_accessor_data_to_file(fp, acc)
                     record[f'{k}_{kc}'] = str(fp)
 
+        # export dataframe and patch masks
+        record = {**record, **self.serialize(where, prefix=prefix)}
+
         return record
 
     def serialize(self, where: Path, prefix='') -> dict:
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index da9b43db..1d2b972b 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -395,10 +395,10 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
         # test on paths in CSV
         test_df = pd.read_csv(where / res['dataframe'])
-        for c in test_df.columns:
-            if '_path' in c:
-                for f in test_df[c]:
-                    self.assertTrue((where / f).exists(), where / f)
+        for c in ['tight_patch_masks_path', 'patch_path']:
+            self.assertTrue(c in test_df.columns)
+            for f in test_df[c]:
+                self.assertTrue((where / f).exists(), where / f)
 
 
 
-- 
GitLab