diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index e43be1f271eeb9ce4d82893589bab4d36fc52053..998b9e635dcb6dbfd2fd8892f98afe053e7d1871 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,4 +1,4 @@ -from itertools import combinations +import itertools from math import sqrt, floor from pathlib import Path from typing import List, Union @@ -120,10 +120,14 @@ def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: return df.loc[df.query(query_str).index, :] -def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: +def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame: """ - Return subset of DataFrame whose bounding boxes overlap in 3D space, with possible duplicate entries where a ROI + If passed a single DataFrame, return the subset whose bounding boxes overlap in 3D space. If passed two DataFrames, + return the subset where a ROI in the first overlaps a ROI in the second. May return duplicates entries where a ROI overlaps with multiple neighbors. + :param df1: DataFrame with potentially overlapping bounding boxes + :param df2: (optional) second DataFrame + :return DataFrame describing subset of overlapping ROIs """ def _compare(r0, r1): @@ -134,14 +138,21 @@ def filter_df_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame: first = [] second = [] - for pair in combinations(df.index, 2): - if _compare(df.iloc[pair[0]], df.iloc[pair[1]]): - first.append(pair[0]) - second.append(pair[1]) - first.append(pair[1]) - second.append(pair[0]) - - sdf = df.iloc[first] + + if df2 is not None: + for pair in itertools.product(df1.index, df2.index): + if _compare(df1.iloc[pair[0]], df2.iloc[pair[1]]): + first.append(pair[0]) + second.append(pair[1]) + else: + for pair in itertools.combinations(df1.index, 2): + if _compare(df1.iloc[pair[0]], df1.iloc[pair[1]]): + first.append(pair[0]) + second.append(pair[1]) + first.append(pair[1]) + second.append(pair[0]) + + sdf = df1.iloc[first] sdf.loc[:, 'bbox_overlaps_with'] = second return sdf diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 1dd92061df5554b60643b8b0727a1992b9c90725..91eda9bc171622a15e889d825bb0f66257103c2a 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -664,7 +664,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): }) res = filter_df_overlap_bbox(df1, df2) self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all()) - self.assertTrue(0) + self.assertEqual(len(res), 1) def test_overlap_seg(self): df = pd.DataFrame({