From 1d5d3c94c6bc0229c01cda50e5ce7da5cc228c58 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 13 Sep 2024 10:13:13 +0200
Subject: [PATCH] Classification transfer returns overlap dataframe for
 debugging purposes

---
 model_server/base/roiset.py | 8 ++++++--
 tests/base/test_roiset.py   | 2 +-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 7b1ef9af..9e1efc42 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -565,11 +565,12 @@ class RoiSet(object):
 
 
     # TODO: typecheck RoiSet not recognized
-    def get_instance_classification(self, roiset_from, iou_min: float = 0.5):
+    def get_instance_classification(self, roiset_from, iou_min: float = 0.5) -> pd.DataFrame:
         """
         Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity
         :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet
         :param iou_min: threshold IOU below which a label is not transferred
+        :return DataFrame of source RoiSet, including overlaps with this RoiSet and IOU metric
         """
         if self.acc_raw.shape != roiset_from.acc_raw.shape:
             raise ShapeMismatchError(
@@ -584,9 +585,10 @@ class RoiSet(object):
             roiset_from.get_df(),
             self.get_df()
         )
+        df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min
         df_merge = pd.merge(
             roiset_from.get_df()[columns],
-            df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']],
+            df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']],
             left_index=True,
             right_index=True,
             how='inner',
@@ -594,6 +596,8 @@ class RoiSet(object):
         for col in columns:
             self.set_classification(col, df_merge[col])
 
+        return df_overlaps
+
     def get_object_class_map(self, name: str) -> InMemoryDataAccessor:
         """
         For a given classification result, return a map where object IDs are replaced by each object's class
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 53c1f042..75a9b775 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -209,7 +209,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
         self.assertTrue('dummy_class' in roiset1.classification_columns)
         self.assertFalse('dummy_class' in roiset2.classification_columns)
-        roiset2.get_instance_classification(roiset1)
+        res = roiset2.get_instance_classification(roiset1)
         self.assertTrue('dummy_class' in roiset2.classification_columns)
         self.assertLess(
             roiset2.get_df().classify_by_dummy_class.count(),
-- 
GitLab