From 940b79410abfa86f13fce1ea706e8f4ae622d0f9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 12 Sep 2024 13:03:36 +0200
Subject: [PATCH] Tested segmentation from two dataframes

---
 model_server/base/roiset.py | 24 +++++++++++++++------
 tests/base/test_roiset.py   | 43 +++++++++++++++++++++++++++++++++++--
 2 files changed, 59 insertions(+), 8 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index b5618abe..fbda1cf1 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -169,17 +169,28 @@ def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Da
     sdf.loc[:, 'bbox_intersec'] = intersec
     return sdf
 
-# TODO: option to quantify overlap e.g. IOU
-def filter_df_overlap_seg(df: pd.DataFrame) -> pd.DataFrame:
+
+def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
     """
-    Return subset of DataFrame whose segmentations overlap in 3D space.
+    If passed a single DataFrame, return the subset whose segmentations overlap in 3D space.  If passed two DataFrames,
+    return the subset where a ROI in the first overlaps a ROI in the second.  May return duplicate entries where a ROI
+    overlaps with multiple neighbors.
+    :param df1: DataFrame with potentially overlapping bounding boxes
+    :param df2: (optional) second DataFrame
+    :return: DataFrame describing subset of overlapping ROIs
+        seg_overlaps_with: index of ROI that overlaps
+        seg_intersec: pixel area of intersecting region
+        seg_iou: intersection over union
     """
 
-    dfbb = filter_df_overlap_bbox(df)
+    dfbb = filter_df_overlap_bbox(df1, df2)
 
     def _overlap_seg(r):
-        roi1 = df.loc[r.name]
-        roi2 = df.loc[r.bbox_overlaps_with]
+        roi1 = df1.loc[r.name]
+        if df2 is not None:
+            roi2 = df2.loc[r.bbox_overlaps_with]
+        else:
+            roi2 = df1.loc[r.bbox_overlaps_with]
         ex0 = min(roi1.x0, roi2.x0, roi1.x1, roi2.x1)
         ew = max(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ex0
         ey0 = min(roi1.y0, roi2.y0, roi1.y1, roi2.y1)
@@ -197,6 +208,7 @@ def filter_df_overlap_seg(df: pd.DataFrame) -> pd.DataFrame:
     dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum())
     return dfbb
 
+
 def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
     """
     Build dataframe associate object IDs with summary stats
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 5d14bfab..c4d98df8 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -647,7 +647,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertTrue((res.bbox_intersec == 2).all())
         return res
 
-    # TODO: test overloaded condition comparing two DFs
+
     def test_overlap_bbox_multiple(self):
         df1 = pd.DataFrame({
             'x0': [0, 1],
@@ -668,6 +668,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertEqual(len(res), 1)
         self.assertTrue((res.bbox_intersec == 2).all())
 
+
     def test_overlap_seg(self):
         df = pd.DataFrame({
             'x0': [0, 1, 2],
@@ -693,4 +694,42 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
         res = filter_df_overlap_seg(df)
         self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all())
-        self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
\ No newline at end of file
+        self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
+
+    def test_overlap_seg_multiple(self):
+        df1 = pd.DataFrame({
+            'x0': [0, 1],
+            'x1': [2, 3],
+            'y0': [0, 0],
+            'y1': [2, 2],
+            'zi': [0, 0],
+            'binary_mask': [
+                [
+                    [1, 1],
+                    [1, 0]
+                ],
+                [
+                    [0, 1],
+                    [1, 1]
+                ],
+            ]
+        })
+        df2 = pd.DataFrame({
+            'x0': [2],
+            'x1': [4],
+            'y0': [0],
+            'y1': [2],
+            'zi': [0],
+            'binary_mask': [
+                [
+                    [1, 1],
+                    [1, 1]
+                ],
+            ]
+        })
+        res = filter_df_overlap_seg(df1, df2)
+        self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all())
+        self.assertEqual(len(res), 1)
+        self.assertTrue((res.bbox_intersec == 2).all())
+        self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all())
+        self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
-- 
GitLab