diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index e1b1c7f397120a2b53508f941e3ef4e07605079d..b5618abe28313458e845685a0734a323b57dfb16 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -189,9 +189,12 @@ def filter_df_overlap_seg(df: pd.DataFrame) -> pd.DataFrame: sl2 = np.s_[(roi2.y0 - ey0): (roi2.y1 - ey0), (roi2.x0 - ex0): (roi2.x1 - ex0)] emask[sl1] = roi1.binary_mask emask[sl2] = emask[sl2] + roi2.binary_mask - return np.any(emask > 1) + return emask - dfbb['seg_overlaps'] = dfbb.apply(_overlap_seg, axis=1) + emasks = dfbb.apply(_overlap_seg, axis=1) + dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1)) + dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum()) + dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum()) return dfbb def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 80da0bc9438062cd8f9d48fe3f60759bce8684f8..5d14bfabf5e0c941b008720e76c8d5dafd8cadfb 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -692,4 +692,5 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): }) res = filter_df_overlap_seg(df) - self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all()) \ No newline at end of file + self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all()) + self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all()) \ No newline at end of file