diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 24b9dd6c461c9d36d864017b29df9ad878c0e69b..f48291123fa48e1f0c6abf8f74381cbb7c7de17b 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -165,7 +165,7 @@ def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Da
                 intersec.append(isc)
 
     sdf = df1.iloc[first]
-    sdf.loc[:, 'bbox_overlaps_with'] = second
+    sdf.loc[:, 'overlaps_with'] = second
     sdf.loc[:, 'bbox_intersec'] = intersec
     return sdf
 
@@ -188,9 +188,9 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat
     def _overlap_seg(r):
         roi1 = df1.loc[r.name]
         if df2 is not None:
-            roi2 = df2.loc[r.bbox_overlaps_with]
+            roi2 = df2.loc[r.overlaps_with]
         else:
-            roi2 = df1.loc[r.bbox_overlaps_with]
+            roi2 = df1.loc[r.overlaps_with]
         ex0 = min(roi1.x0, roi2.x0, roi1.x1, roi2.x1)
         ew = max(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ex0
         ey0 = min(roi1.y0, roi2.y0, roi1.y1, roi2.y1)
diff --git a/model_server/base/util.py b/model_server/base/util.py
index 9b95edca73fbef1e831d4e423e7a30bc82ad5907..5837576172be43a7f80cc57008b018c3f4a15e49 100644
--- a/model_server/base/util.py
+++ b/model_server/base/util.py
@@ -8,6 +8,7 @@ import pandas as pd
 
 from .accessors import InMemoryDataAccessor, write_accessor_data_to_file
 from .models import Model
+from .roiset import filter_df_overlap_seg, RoiSet
 
 def autonumber_new_directory(where: str, prefix: str) -> str:
     """
@@ -164,3 +165,39 @@ def loop_workflow(
 
     if len(failures) > 0:
         pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv')
+
+def transfer_classification(
+        r1: RoiSet,
+        r2: RoiSet,
+        iou_min: float = 0.5,
+):
+
+    if r1.acc_raw.shape != r2.acc_raw.shape:
+        raise RoiSetShapeMismatchError(f'Expecting two RoiSets of same shape: {r1.acc_raw.shape} != {r2.acc_raw.shape}')
+
+    classes = r1.classification_columns
+    class_columns = [f'classify_by_{c}' for c in r1.classification_columns]
+
+    df_overlaps = filter_df_overlap_seg(r1.get_df(), r2.get_df())
+    df_merge = pd.merge(
+        r1.get_df()[class_columns],
+        df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']],
+        left_index=True,
+        right_index=True,
+        how='inner',
+        # suffixes=['r1', 'r2']
+    )
+
+    print('hi')
+    # for cl in classes:
+    #     se = pd.Series(dtype='Int64', index=r2.get_df().index)
+    #     se.loc[:] = None
+    #     se.loc[df_keep.]
+    #
+    # r1.loc[df]
+
+class Error(Exception):
+    pass
+
+class RoiSetShapeMismatchError(Error):
+    pass
\ No newline at end of file
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 887ad06977bad7d0cc165b5729be084892aecfb0..f36d1dc8bcc4c56d85ef41d1ec3dffa3f1230a2d 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -11,6 +11,7 @@ from model_server.base.roiset import RoiSet
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
 from model_server.base.models import DummyInstanceSegmentationModel
 from model_server.base.process import smooth
+from model_server.base.util import transfer_classification
 import model_server.conf.testing as conf
 
 data = conf.meta['image_files']
@@ -195,10 +196,23 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_transfer_classification(self):
         roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask)
-        smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=3.0))
+
+        # prepare alternative mask and compare
+        smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=1.5))
         roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask)
+        dmask = (self.seg_mask.data / 255) + (smoothed_mask.data / 255)
+        self.assertTrue(np.all(np.unique(dmask) == [0, 1, 2]))
+        total_iou = (dmask == 2).sum() / ((dmask == 1).sum() + (dmask == 2).sum())
+        self.assertGreater(total_iou, 0.6)
+
+        # classify first RoiSet
+        roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
+
+        self.assertTrue('dummy_class' in roiset1.classification_columns)
+        transfer_classification(roiset1, roiset2)
         self.assertTrue(False)
 
+
     def test_classify_by_with_derived_channel(self):
         class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
             def infer(self, img, mask):
@@ -648,8 +662,8 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
         res = filter_df_overlap_bbox(df)
         self.assertEqual(len(res), 4)
-        self.assertTrue((res.loc[0, 'bbox_overlaps_with'] == [1]).all())
-        self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0, 2]).all())
+        self.assertTrue((res.loc[0, 'overlaps_with'] == [1]).all())
+        self.assertTrue((res.loc[1, 'overlaps_with'] == [0, 2]).all())
         self.assertTrue((res.bbox_intersec == 2).all())
         return res
 
@@ -670,7 +684,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
             'zi': [0],
         })
         res = filter_df_overlap_bbox(df1, df2)
-        self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all())
+        self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all())
         self.assertEqual(len(res), 1)
         self.assertTrue((res.bbox_intersec == 2).all())
 
@@ -734,7 +748,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
             ]
         })
         res = filter_df_overlap_seg(df1, df2)
-        self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all())
+        self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all())
         self.assertEqual(len(res), 1)
         self.assertTrue((res.bbox_intersec == 2).all())
         self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all())