diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 7b1ef9afe2bcbfa79797676a859c8e929c33a5b0..9e1efc4274aa5769dba684ab266841b2ff366842 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -565,11 +565,12 @@ class RoiSet(object): # TODO: typecheck RoiSet not recognized - def get_instance_classification(self, roiset_from, iou_min: float = 0.5): + def get_instance_classification(self, roiset_from, iou_min: float = 0.5) -> pd.DataFrame: """ Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet :param iou_min: threshold IOU below which a label is not transferred + :return DataFrame of source RoiSet, including overlaps with this RoiSet and IOU metric """ if self.acc_raw.shape != roiset_from.acc_raw.shape: raise ShapeMismatchError( @@ -584,9 +585,10 @@ class RoiSet(object): roiset_from.get_df(), self.get_df() ) + df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min df_merge = pd.merge( roiset_from.get_df()[columns], - df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']], + df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']], left_index=True, right_index=True, how='inner', @@ -594,6 +596,8 @@ class RoiSet(object): for col in columns: self.set_classification(col, df_merge[col]) + return df_overlaps + def get_object_class_map(self, name: str) -> InMemoryDataAccessor: """ For a given classification result, return a map where object IDs are replaced by each object's class diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 53c1f04277495ffe8a6ef3087187c4e68585c36b..75a9b7752984998f081fd77ecbf8589efb9304e3 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -209,7 +209,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue('dummy_class' in roiset1.classification_columns) self.assertFalse('dummy_class' in roiset2.classification_columns) - roiset2.get_instance_classification(roiset1) + res = roiset2.get_instance_classification(roiset1) self.assertTrue('dummy_class' in roiset2.classification_columns) self.assertLess( roiset2.get_df().classify_by_dummy_class.count(),