From 8ed6d1fb8345acfc1fda3f8c83b794c0564049de Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 15 Aug 2024 11:23:55 +0200
Subject: [PATCH] Fixed bug where empty RoiSet throws error when calling
 classify_by

---
 model_server/base/roiset.py | 3 +++
 tests/base/test_roiset.py   | 2 ++
 2 files changed, 5 insertions(+)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index c220c215..bccbcb05 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -486,6 +486,9 @@ class RoiSet(object):
         :param object_classification_model: InstanceSegmentation model object
         :return: None
         """
+        if self.count == 0:
+            self._df['classify_by_' + name] = None
+            return True
 
         input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 28885525..44486110 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -81,6 +81,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype))
         roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap)
         self.assertEqual(roiset.count, 0)
+        roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
+        self.assertTrue('classify_by_dummy_class' in roiset.get_df().columns)
 
     def test_slices_are_valid(self):
         roiset = self._make_roi_set()
-- 
GitLab