From 1de55e7c275627169cb54739a3e52339b13e60a8 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 2 Aug 2024 16:30:40 +0200
Subject: [PATCH] filter_df_overlap_bbox now includes both sides of an
 overlapping Roi pair as separate rows

---
 model_server/base/roiset.py | 7 +++++--
 tests/base/test_roiset.py   | 6 +++---
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 11c66d43..e790f7d0 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -122,7 +122,8 @@ def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
 
 def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
     """
-    Return subset of DataFrame whose bounding boxes overlap in 3D space
+    Return subset of DataFrame whose bounding boxes overlap in 3D space, with possible duplicate entries where a ROI
+    overlaps with multiple neighbors.
     """
 
     def _compare(r0, r1):
@@ -137,6 +138,8 @@ def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
         if _compare(df.iloc[pair[0]], df.iloc[pair[1]]):
             first.append(pair[0])
             second.append(pair[1])
+            first.append(pair[1])
+            second.append(pair[0])
 
     sdf = df.iloc[first]
     sdf['bbox_overlaps_with'] = second
@@ -145,7 +148,7 @@ def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
 # TODO: get overlapping segments
 def filter_df_overlap_seg(df: pd.DataFrame) -> pd.DataFrame:
     """
-    Return subset of DataFrame whose segmentations overlap in 3D space
+    Return subset of DataFrame whose segmentations overlap in 3D space.
     """
 
     dfbb = filter_df_overlap_bbox(df)
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index e1c7f800..b63c0aaf 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -650,9 +650,9 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
         })
 
         res = filter_df_overlap_bbox(df)
-        self.assertEqual(len(res), 2)
-        self.assertTrue((res.loc[0, 'bbox_overlaps_with'] == 1).all())
-        self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == 2).all())
+        self.assertEqual(len(res), 4)
+        self.assertTrue((res.loc[0, 'bbox_overlaps_with'] == [1]).all())
+        self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0, 2]).all())
         return res
 
     def test_overlap_seg(self):
-- 
GitLab