From 19261090ca52d37b8f00c79ebed96ae9a1d347ee Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 12 Sep 2024 12:00:39 +0200
Subject: [PATCH] Implemented segmentation IOU

---
 model_server/base/roiset.py | 7 +++++--
 tests/base/test_roiset.py   | 3 ++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index e1b1c7f3..b5618abe 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -189,9 +189,12 @@ def filter_df_overlap_seg(df: pd.DataFrame) -> pd.DataFrame:
         sl2 = np.s_[(roi2.y0 - ey0): (roi2.y1 - ey0), (roi2.x0 - ex0): (roi2.x1 - ex0)]
         emask[sl1] = roi1.binary_mask
         emask[sl2] = emask[sl2] + roi2.binary_mask
-        return np.any(emask > 1)
+        return emask
 
-    dfbb['seg_overlaps'] = dfbb.apply(_overlap_seg, axis=1)
+    emasks = dfbb.apply(_overlap_seg, axis=1)
+    dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1))
+    dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum())
+    dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum())
     return dfbb
 
 def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 80da0bc9..5d14bfab 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -692,4 +692,5 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
         })
 
         res = filter_df_overlap_seg(df)
-        self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all())
\ No newline at end of file
+        self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all())
+        self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
\ No newline at end of file
-- 
GitLab