From 1d5d3c94c6bc0229c01cda50e5ce7da5cc228c58 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 13 Sep 2024 10:13:13 +0200 Subject: [PATCH] Classification transfer returns overlap dataframe for debugging purposes --- model_server/base/roiset.py | 8 ++++++-- tests/base/test_roiset.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 7b1ef9af..9e1efc42 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -565,11 +565,12 @@ class RoiSet(object): # TODO: typecheck RoiSet not recognized - def get_instance_classification(self, roiset_from, iou_min: float = 0.5): + def get_instance_classification(self, roiset_from, iou_min: float = 0.5) -> pd.DataFrame: """ Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet :param iou_min: threshold IOU below which a label is not transferred + :return DataFrame of source RoiSet, including overlaps with this RoiSet and IOU metric """ if self.acc_raw.shape != roiset_from.acc_raw.shape: raise ShapeMismatchError( @@ -584,9 +585,10 @@ class RoiSet(object): roiset_from.get_df(), self.get_df() ) + df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min df_merge = pd.merge( roiset_from.get_df()[columns], - df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']], + df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']], left_index=True, right_index=True, how='inner', @@ -594,6 +596,8 @@ class RoiSet(object): for col in columns: self.set_classification(col, df_merge[col]) + return df_overlaps + def get_object_class_map(self, name: str) -> InMemoryDataAccessor: """ For a given classification result, return a map where object IDs are replaced by each object's class diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 53c1f042..75a9b775 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -209,7 +209,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue('dummy_class' in roiset1.classification_columns) self.assertFalse('dummy_class' in roiset2.classification_columns) - roiset2.get_instance_classification(roiset1) + res = roiset2.get_instance_classification(roiset1) self.assertTrue('dummy_class' in roiset2.classification_columns) self.assertLess( roiset2.get_df().classify_by_dummy_class.count(), -- GitLab