From 0d0a86bf7f6aa27913ae20384c40f327cd1f2f3b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 2 Apr 2024 18:49:30 +0200
Subject: [PATCH] Force export of dataset and patches; just need to override
 bounding box expansion

---
 model_server/base/roiset.py | 24 ++++++++++--------
 tests/test_roiset.py        | 50 +++++++++++++++++++++++++------------
 2 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 29a897df..a5aaf763 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -56,11 +56,8 @@ class RoiSetExportParams(BaseModel):
     patches_3d: Union[PatchParams, None] = None
     annotated_patches_2d: Union[PatchParams, None] = None
     patches_2d: Union[PatchParams, None] = None
-    patch_masks: Union[PatchParams, None] = None
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
     object_classes: bool = False
-    dataframe: bool = False
-
 
 
 
@@ -350,7 +347,13 @@ class RoiSet(object):
             om[self.acc_obj_ids.data == roi.label] = oc
         self.object_class_maps[name] = InMemoryDataAccessor(om)
 
-    def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
+    def export_dataframe(self, csv_path: Path):
+        csv_path.parent.mkdir(parents=True, exist_ok=True)
+        self._df.drop(['slice', 'relative_slice', 'mask'], axis=1).to_csv(csv_path, index=False)
+        return csv_path
+
+
+    def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask') -> list:
         patches_acc = self.get_patch_masks(pad_to=pad_to)
 
         exported = []
@@ -538,6 +541,11 @@ class RoiSet(object):
         if not self.count:
             return
         raw_ch = self.acc_raw.get_one_channel_data(channel)
+
+        # export dataframe and patch masks
+        record['dataframe'] = self.export_dataframe(where / 'dataframe' / (prefix + '.csv'))
+        self.export_patch_masks(where / 'patch_masks', prefix=prefix, pad_to=None)
+
         for k in params.dict().keys():
             subdir = where / k
             pr = prefix
@@ -560,8 +568,6 @@ class RoiSet(object):
                 df_patches = pd.DataFrame(files)
                 self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
                 self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
-            if k == 'patch_masks':
-                self.export_patch_masks(subdir, prefix=pr, **kp)
             if k == 'annotated_zstacks':
                 self.export_annotated_zstack(subdir, prefix=pr, **kp)
             if k == 'object_classes':
@@ -569,11 +575,7 @@ class RoiSet(object):
                     fp = subdir / kc / (pr + '.tif')
                     write_accessor_data_to_file(fp, acc)
                     record[f'{k}_{kc}'] = fp
-            if k == 'dataframe':
-                dfpa = subdir / (pr + '.csv')
-                dfpa.parent.mkdir(parents=True, exist_ok=True)
-                self._df.drop(['slice', 'relative_slice', 'mask'], axis=1).to_csv(dfpa, index=False)
-                record[k] = dfpa
+
         return record
 
 
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index 596b66d0..b5e9175a 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -1,3 +1,5 @@
+import os
+import re
 import unittest
 
 import numpy as np
@@ -333,21 +335,37 @@ class TestRoiSetFromZmask(unittest.TestCase):
 
         df_test = pd.read_csv(where_df)
 
-        # zmask = np.zeros((*self.stack.hw, 1, self.stack.nz), dtype=bool)
-        print('hi')
-
-        fn = output_path / 'roiset_from_3d' / 'patch_masks' / 'ref-la{:04d}-zi{:04d}.png'
-        patch_masks = {}
-
-        def _label_obj(r):
-            sl = np.s_[r.ebb_y0:r.ebb_y1, r.ebb_x0:r.ebb_x1, :, r.zi:r.zi + 1]
-            self.assertEqual(str(sl), r.slice)
-            patch_masks[r.label] = generate_file_accessor(str(fn).format(r.label, r.zi)).data
-            # zmask[sl] = True
-
-        df_test.apply(lambda x: _label_obj(x), axis=1)
-
-        roiset_test = RoiSet.from_df_and_patch_masks(self.stack, df_test, patch_masks)
-        print('')
+        # check that patches are correct size
+        where_patch_masks = output_path / 'roiset_from_3d' / 'patch_masks'
+        for pmf in where_patch_masks.iterdir():
+            self.assertTrue(pmf.suffix.upper() == '.PNG')
+            la = int(re.search(r'la([\d]+)', str(pmf)).group(1))
+            roi_q = df_test.loc[df_test.label == la, :]
+            self.assertEqual(len(roi_q), 1)
+            roi = roi_q.iloc[0]
+            h = int(roi.y1 - roi.y0)
+            w = int(roi.x1 - roi.x0)
+            m_acc = generate_file_accessor(pmf)
+            self.assertEqual((h, w), m_acc.hw)
+
+
+        # df_test = pd.read_csv(where_df)
+        #
+        # # zmask = np.zeros((*self.stack.hw, 1, self.stack.nz), dtype=bool)
+        # print('hi')
+        #
+        # fn = output_path / 'roiset_from_3d' / 'patch_masks' / 'ref-la{:04d}-zi{:04d}.png'
+        # patch_masks = {}
+        #
+        # def _label_obj(r):
+        #     sl = np.s_[r.ebb_y0:r.ebb_y1, r.ebb_x0:r.ebb_x1, :, r.zi:r.zi + 1]
+        #     self.assertEqual(str(sl), r.slice)
+        #     patch_masks[r.label] = generate_file_accessor(str(fn).format(r.label, r.zi)).data
+        #     # zmask[sl] = True
+        #
+        # df_test.apply(lambda x: _label_obj(x), axis=1)
+        #
+        # roiset_test = RoiSet.from_df_and_patch_masks(self.stack, df_test, patch_masks)
+        # print('')
 
 
-- 
GitLab