From 96a892c871fa231d20002cc0a9f938df3cda87e6 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 29 Nov 2024 12:41:12 +0100
Subject: [PATCH] No longer copy and pass around dataframes when running patch
 exports

---
 model_server/base/roiset.py | 59 +++++++++++++---------------------
 tests/base/test_roiset.py   | 64 +++++++++++++++++++------------------
 2 files changed, 56 insertions(+), 67 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 46a7f114..de61c186 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -625,10 +625,9 @@ class RoiSet(object):
 
     def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack:  # padded, un-annotated 2d patches
         if channels and len(channels) == 1:
-            patches_df = self.get_patches(white_channel=channels[0], **kwargs)
+            return PatchStack(list(self.get_patches(white_channel=channels[0], **kwargs)))
         else:
-            patches_df = self.get_patches(channels=channels, **kwargs)
-        return PatchStack(list(patches_df.patch))
+            return PatchStack(list(self.get_patches(channels=channels, **kwargs)))
 
     def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> str:
         annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs))
@@ -817,9 +816,18 @@ class RoiSet(object):
         patches_df['patch_mask_path'] = patches_df.apply(_export_patch_mask, axis=1)
         return patches_df
 
-    def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.DataFrame:
+    def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.Series:
+        """
+        Export each patch to its own file.
+        :param where: location in which to write patch files
+        :param prefix: prefix of each patch's filename
+        :param kwargs: patch formatting options
+        :return: pd.Series of patch paths
+        """
         make_3d = kwargs.get('make_3d', False)
-        patches_df = self.get_patches(**kwargs).copy()
+        patches_df = self._df.join(
+            self.get_patches(**kwargs).rename('patch')
+        )
 
         def _export_patch(roi):
             patch = InMemoryDataAccessor(roi.patch)
@@ -833,8 +841,7 @@ class RoiSet(object):
                 write_accessor_data_to_file(where / fname, patch)
             return fname
 
-        patches_df['patch_path'] = patches_df.apply(_export_patch, axis=1)
-        return patches_df
+        return patches_df.apply(_export_patch, axis=1)
 
     def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame:
         is_3d = is_df_3d(self._df)
@@ -871,7 +878,7 @@ class RoiSet(object):
             white_channel: int = None,
             expanded=False,
             **kwargs
-    ) -> pd.DataFrame:
+    ) -> pd.Series:
 
         # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
         raw = self.acc_raw
@@ -991,10 +998,7 @@ class RoiSet(object):
                 patch = pad(patch, pad_to)
             return patch
 
-        # TODO: just return needed rows, without DataFrame copy
-        dfe = self._df.copy()
-        dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1)
-        return dfe
+        return self._df.apply(lambda r: _make_patch(r), axis=1)
 
     @property
     def classification_columns(self):
@@ -1033,30 +1037,13 @@ class RoiSet(object):
                     subdir = Path(product_name)
                     if params.write_patches_to_subdirectory:
                         subdir = subdir / prefix
-
-                    df_exports = self.export_patches(where / subdir, prefix=prefix, **pp)
-
-                    df_patch_paths = df_exports.patch_path.apply(lambda x: str(subdir / x))
-                    self._df = self._df.join(df_patch_paths)  # TODO: rename his column
-                    self._df[f'{product_name}_id'] = self._df.apply(lambda _: uuid4(), axis=1)
-                    record[product_name] = list(df_patch_paths)
-            # if k == 'patches_3d':
-            #     df_exp = self.export_patches(
-            #         where / k, prefix=prefix, make_3d=True, **kp
-            #     )
-            #     record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
-            # if k == 'annotated_patches_2d':
-            #     df_exp = self.export_patches(
-            #         where / k, prefix=prefix, make_3d=False, **kp,
-            #     )
-            #     record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
-            # if k == 'patches_2d':
-            #     df_exp = self.export_patches(
-            #         where / k, prefix=prefix, make_3d=False, **kp
-            #     )
-            #     self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(k / x)))
-            #     self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
-            #     record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
+                    se_paths = self.export_patches(where / subdir, prefix=prefix, **pp).apply(lambda x: str(subdir / x))
+                    df_patch_info = pd.DataFrame({
+                        f'{product_name}_path': se_paths,
+                        f'{product_name}_id': se_paths.apply(lambda _: uuid4()),
+                    })
+                    self._df = self._df.join(df_patch_info)
+                    record[product_name] = list(se_paths)
             if k == 'annotated_zstacks':
                 record[k] = str(Path(k) / self.export_annotated_zstack(where / k, prefix=prefix, **kp))
             if k == 'object_classes':
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 5640c7f6..e8340d4a 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -112,14 +112,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_make_expanded_2d_patches(self):
         roiset = self._make_roi_set()
         where = output_path / 'expanded_2d_patches'
-        df_res = roiset.export_patches(
+        se_res = roiset.export_patches(
             where,
             draw_bounding_box=True,
             expanded=True,
             pad_to=256,
         )
         df = roiset.get_df()
-        for f in df_res.patch_path:
+        for f in se_res:
             acc = generate_file_accessor(where / f)
             la = int(re.search(r'la([\d]+)', str(f)).group(1))
             roi_q = df.loc[df.label == la, :]
@@ -129,13 +129,13 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_make_tight_2d_patches(self):
         roiset = self._make_roi_set()
         where = output_path / 'tight_2d_patches'
-        df_res = roiset.export_patches(
+        se_res = roiset.export_patches(
             where,
             draw_bounding_box=True,
             expanded=False
         )
         df = roiset.get_df()
-        for f in df_res.patch_path:  # all exported files are same shape as bounding boxes in RoiSet's datatable
+        for f in se_res:  # all exported files are same shape as bounding boxes in RoiSet's datatable
             acc = generate_file_accessor(where / f)
             la = int(re.search(r'la([\d]+)', str(f)).group(1))
             roi_q = df.loc[df.label == la, :]
@@ -146,13 +146,13 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_make_expanded_3d_patches(self):
         roiset = self._make_roi_set()
         where = output_path / '3d_patches'
-        df_res = roiset.export_patches(
+        se_res = roiset.export_patches(
             where,
             make_3d=True,
             expanded=True
         )
-        self.assertGreaterEqual(len(df_res), 1)
-        for f in df_res.patch_path:
+        self.assertGreaterEqual(len(se_res), 1)
+        for f in se_res:
             acc = generate_file_accessor(where / f)
             self.assertGreater(acc.nz, 1)
 
@@ -354,14 +354,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def test_multichannel_to_mono_2d_patches(self):
         where = output_path / 'multichannel' / 'mono_2d_patches'
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             white_channel=0,
             draw_bounding_box=True,
             expanded=True,
             pad_to=256,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, 1)
 
     def test_multichannel_to_color_2d_patches(self):
@@ -371,7 +371,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         patches_acc = self.roiset.get_patches_acc(channels=chs)
         self.assertEqual(patches_acc.chroma, len(chs))
 
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             channels=chs,
             draw_bounding_box=True,
@@ -379,12 +379,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             pad_to=256,
             force_tif=True,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, len(chs))
 
     def test_multichannnel_to_mono_2d_patches_rgb_bbox(self):
         where = output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox'
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             white_channel=3,
             draw_bounding_box=True,
@@ -392,12 +392,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             expanded=True,
             pad_to=256,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, 3)
 
     def test_multichannnel_to_rgb_2d_patches_bbox(self):
         where = output_path / 'multichannel' / 'rgb_2d_patches_bbox'
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             white_channel=4,
             rgb_overlay_channels=(3, None, None),
@@ -408,12 +408,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             expanded=True,
             pad_to=256,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, 3)
 
     def test_multichannnel_to_rgb_2d_patches_mask(self):
         where = output_path / 'multichannel' / 'rgb_2d_patches_mask'
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             white_channel=4,
             rgb_overlay_channels=(3, None, None),
@@ -423,12 +423,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             expanded=True,
             pad_to=256,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, 3)
 
     def test_multichannnel_to_rgb_2d_patches_contour(self):
         where = output_path / 'multichannel' / 'rgb_2d_patches_contour'
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             rgb_overlay_channels=(3, None, None),
             draw_contour=True,
@@ -437,18 +437,18 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             expanded=True,
             pad_to=256,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, 3)
         self.assertEqual(result.get_mono(2).data.max(), 0)  # blue channel is black
 
     def test_multichannel_to_multichannel_tif_patches(self):
         where = output_path / 'multichannel' / 'multichannel_tif_patches'
-        df_res = self.roiset.export_patches(
+        se_res = self.roiset.export_patches(
             where,
             expanded=True,
             pad_to=256,
         )
-        result = generate_file_accessor(where / df_res.patch_path.iloc[0])
+        result = generate_file_accessor(where / se_res.iloc[0])
         self.assertEqual(result.chroma, 5)
         self.assertEqual(result.nz, 1)
 
@@ -556,7 +556,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
         # test on paths in CSV
         test_df = pd.read_csv(where / res['dataframe'])
-        for c in ['tight_patch_masks_path', 'patch_path']:
+        for c in ['tight_patch_masks_path', 'patches_2d_path', 'annotated_patches_2d']:
             self.assertTrue(c in test_df.columns)
             for f in test_df[c]:
                 self.assertTrue((where / f).exists(), where / f)
@@ -564,12 +564,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
     def test_get_interm_prods(self):
         p = RoiSetExportParams(**{
             'patches': {
-                'patches_2d': {
+                '2d': {
                     'white_channel': 3,
                     'draw_bounding_box': False,
                     'draw_mask': False,
                 },
-                'annotated_patches_2d': {
+                '2d_annotated': {
                     'white_channel': 3,
                     'draw_bounding_box': True,
                     'rgb_overlay_channels': [3, None, None],
@@ -586,7 +586,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         )
         self.assertNotIn('patches_3d', interm.keys())
         self.assertEqual(
-            interm['annotated_patches_2d'].hw,
+            interm['patches_2d_annotated'].hw,
             (self.roiset.get_df().h.max(), self.roiset.get_df().w.max())
         )
         self.assertEqual(
@@ -605,12 +605,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def test_run_export_expanded_2d_patch(self):
         p = RoiSetExportParams(**{
-            'patches_2d': {
-                'white_channel': -1,
-                'draw_bounding_box': False,
-                'draw_mask': False,
-                'expanded': True,
-                'pad_to': 256,
+            'patches': {
+                '2d': {
+                    'white_channel': -1,
+                    'draw_bounding_box': False,
+                    'draw_mask': False,
+                    'expanded': True,
+                    'pad_to': 256,
+                },
             },
         })
         self.assertTrue(hasattr(p.patches['2d'], 'pad_to'))
-- 
GitLab