From f7037423fb588ff4465d156a077e02c4afd7fd56 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 6 Sep 2024 08:56:31 -0700
Subject: [PATCH] Successful test of multi-set bbox overlap

---
 model_server/base/roiset.py | 33 ++++++++++++++++++++++-----------
 tests/base/test_roiset.py   |  2 +-
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index e43be1f2..998b9e63 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1,4 +1,4 @@
-from itertools import combinations
+import itertools
 from math import sqrt, floor
 from pathlib import Path
 from typing import List, Union
@@ -120,10 +120,14 @@ def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
     return df.loc[df.query(query_str).index, :]
 
 
-def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
+def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
     """
-    Return subset of DataFrame whose bounding boxes overlap in 3D space, with possible duplicate entries where a ROI
+    If passed a single DataFrame, return the subset whose bounding boxes overlap in 3D space.  If passed two DataFrames,
+    return the subset where a ROI in the first overlaps a ROI in the second.  May return duplicates entries where a ROI
     overlaps with multiple neighbors.
+    :param df1: DataFrame with potentially overlapping bounding boxes
+    :param df2: (optional) second DataFrame
+    :return DataFrame describing subset of overlapping ROIs
     """
 
     def _compare(r0, r1):
@@ -134,14 +138,21 @@ def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
 
     first = []
     second = []
-    for pair in combinations(df.index, 2):
-        if _compare(df.iloc[pair[0]], df.iloc[pair[1]]):
-            first.append(pair[0])
-            second.append(pair[1])
-            first.append(pair[1])
-            second.append(pair[0])
-
-    sdf = df.iloc[first]
+
+    if df2 is not None:
+        for pair in itertools.product(df1.index, df2.index):
+            if _compare(df1.iloc[pair[0]], df2.iloc[pair[1]]):
+                first.append(pair[0])
+                second.append(pair[1])
+    else:
+        for pair in itertools.combinations(df1.index, 2):
+            if _compare(df1.iloc[pair[0]], df1.iloc[pair[1]]):
+                first.append(pair[0])
+                second.append(pair[1])
+                first.append(pair[1])
+                second.append(pair[0])
+
+    sdf = df1.iloc[first]
     sdf.loc[:, 'bbox_overlaps_with'] = second
     return sdf
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 1dd92061..91eda9bc 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -664,7 +664,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
         })
         res = filter_df_overlap_bbox(df1, df2)
         self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all())
-        self.assertTrue(0)
+        self.assertEqual(len(res), 1)
 
     def test_overlap_seg(self):
         df = pd.DataFrame({
-- 
GitLab