From 583c415be167c64b1773000ed749ebbada1f1e4f Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 10 Dec 2024 09:33:15 +0100
Subject: [PATCH] Export products no longer throw exceptions when called on
 empty RoiSet

---
 model_server/base/roiset.py | 15 +++++++++++++--
 tests/base/test_roiset.py   |  8 ++++++--
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 41edb804..8a1ff62a 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -630,7 +630,9 @@ class RoiSet(object):
     def add_df_col(self, name, se: pd.Series) -> None:
         self._df[name] = se
 
-    def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack:  # padded, un-annotated 2d patches
+    def get_patches_acc(self, channels: list = None, **kwargs) -> Union[PatchStack, None]:  # padded, un-annotated 2d patches
+        if self.count == 0:
+            return None
         if channels and len(channels) == 1:
             return PatchStack(list(self.get_patches(white_channel=channels[0], **kwargs)))
         else:
@@ -857,6 +859,10 @@ class RoiSet(object):
         return patches_df.apply(_export_patch, axis=1)
 
     def get_patch_masks(self, pad_to: int = None, expanded: bool = False, make_3d=True) -> pd.DataFrame:
+
+        if self.count == 0:
+            return pd.DataFrame()
+
         def _make_patch_mask(roi):
             if expanded:
                 patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
@@ -877,7 +883,9 @@ class RoiSet(object):
         dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1)
         return dfe
 
-    def get_patch_masks_acc(self, **kwargs) -> PatchStack:
+    def get_patch_masks_acc(self, **kwargs) -> Union[PatchStack, None]:
+        if self.count == 0:
+            return None
         se_pm = self.get_patch_masks(**kwargs).patch_mask
         se_ext = se_pm.apply(lambda x: np.expand_dims(x, 2))
         return PatchStack(list(se_ext))
@@ -896,6 +904,9 @@ class RoiSet(object):
             **kwargs
     ) -> pd.Series:
 
+        if self.count == 0:
+            return pd.Series()
+
         # arrange RGB channels if so specified, otherwise copy roiset.raw_acc data
         raw = self.acc_raw
         if isinstance(rgb_overlay_channels, (list, tuple)) and isinstance(rgb_overlay_weights, (list, tuple)):
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index c134aa9d..943fd6a9 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -871,8 +871,12 @@ class TestEmptyRoiSet(unittest.TestCase):
 
     def test_classify_by(self):
         roiset = self.empty_roiset
-        res = roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0))
-        self.assertEqual(1, 0)
+        self.assertFalse('classify_by_permissive_model' in roiset.get_df().columns)
+        self.assertTrue(
+            roiset.classify_by('permissive_model', [0], IntensityThresholdInstanceMaskSegmentationModel(tr=0.0))
+        )
+        self.assertEqual(roiset.count, 0)
+        self.assertTrue('classify_by_permissive_model' in roiset.get_df().columns)
 
 class TestRoiSetObjectDetection(unittest.TestCase):
 
-- 
GitLab