From 83259e776a8c752a0da774a68cf61b72e1ab6e92 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 15 Jul 2024 17:40:22 +0200
Subject: [PATCH] Implemented logic to find overlapping bounding boxes

---
 model_server/base/roiset.py | 34 +++++++++++++++++++++++++++++++++-
 tests/base/test_roiset.py   | 20 +++++++++++++++++++-
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 1e39496b..360d86f9 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1,6 +1,6 @@
+from itertools import combinations
 from math import sqrt, floor
 from pathlib import Path
-import re
 from typing import List, Union
 from uuid import uuid4
 
@@ -108,6 +108,26 @@ def _focus_metrics():
         'moment': lambda x: moment(x.flatten(), moment=2),
     }
 
+# TODO: get overlapping bounding boxes
+def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
+
+    def _compare(r0, r1):
+        olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)
+        oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0)
+        olz = (r0.zi == r1.zi)
+        return olx and oly and olz
+
+    first = []
+    second = []
+    for pair in combinations(df.index, 2):
+        if _compare(df.iloc[pair[0]], df.iloc[pair[1]]):
+            first.append(pair[0])
+            second.append(pair[1])
+
+    sdf = df.iloc[first]
+    sdf['overlaps_with'] = second
+    return sdf
+
 
 def _safe_add(a, g, b):
     assert a.dtype == b.dtype
@@ -178,6 +198,10 @@ class RoiSet(object):
         """
         return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
 
+
+    # TODO: generate overlapping RoiSet from multiple masks
+    # call e.g. static adder
+
     @staticmethod
     def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
         """
@@ -258,6 +282,14 @@ class RoiSet(object):
         )
         return df
 
+
+    # TODO: get overlapping segments
+    def get_overlap_seg(self) -> pd.DataFrame:
+        df_overlap_bbox = self.get_overlap_bbox()
+
+
+    # TODO: test if overlaps exist
+
     @staticmethod
     def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
         query_str = 'label > 0'  # always true
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index beff2a73..6f39f239 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 import pandas as pd
 
-from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
+from model_server.base.roiset import _filter_overlap_bbox, RoiSetExportParams, RoiSetMetaParams
 from model_server.base.roiset import RoiSet
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
 from model_server.base.models import DummyInstanceSegmentationModel
@@ -17,6 +17,24 @@ data = conf.meta['image_files']
 output_path = conf.meta['output_path']
 params = conf.meta['roiset']
 
+class TestOverlapLogic(unittest.TestCase):
+
+    def test_overlap_bbox(self):
+        df = pd.DataFrame({
+            'x0': [0, 1, 3, 1, 1],
+            'x1': [2, 3, 4, 3, 3],
+            'y0': [0, 0, 0, 1, 0],
+            'y1': [1, 1, 1, 2, 1],
+            'zi': [0, 0, 0, 0, 1],
+        })
+
+        res = _filter_overlap_bbox(df)
+        print(res)
+
+        self.assertEqual(len(res), 1)
+        self.assertTrue((res.loc[0, 'overlaps_with'] == 1).all())
+
+
 class BaseTestRoiSetMonoProducts(object):
 
     def setUp(self) -> None:
-- 
GitLab