diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 24b9dd6c461c9d36d864017b29df9ad878c0e69b..f48291123fa48e1f0c6abf8f74381cbb7c7de17b 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -165,7 +165,7 @@ def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Da intersec.append(isc) sdf = df1.iloc[first] - sdf.loc[:, 'bbox_overlaps_with'] = second + sdf.loc[:, 'overlaps_with'] = second sdf.loc[:, 'bbox_intersec'] = intersec return sdf @@ -188,9 +188,9 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat def _overlap_seg(r): roi1 = df1.loc[r.name] if df2 is not None: - roi2 = df2.loc[r.bbox_overlaps_with] + roi2 = df2.loc[r.overlaps_with] else: - roi2 = df1.loc[r.bbox_overlaps_with] + roi2 = df1.loc[r.overlaps_with] ex0 = min(roi1.x0, roi2.x0, roi1.x1, roi2.x1) ew = max(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ex0 ey0 = min(roi1.y0, roi2.y0, roi1.y1, roi2.y1) diff --git a/model_server/base/util.py b/model_server/base/util.py index 9b95edca73fbef1e831d4e423e7a30bc82ad5907..5837576172be43a7f80cc57008b018c3f4a15e49 100644 --- a/model_server/base/util.py +++ b/model_server/base/util.py @@ -8,6 +8,7 @@ import pandas as pd from .accessors import InMemoryDataAccessor, write_accessor_data_to_file from .models import Model +from .roiset import filter_df_overlap_seg, RoiSet def autonumber_new_directory(where: str, prefix: str) -> str: """ @@ -164,3 +165,39 @@ def loop_workflow( if len(failures) > 0: pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') + +def transfer_classification( + r1: RoiSet, + r2: RoiSet, + iou_min: float = 0.5, +): + + if r1.acc_raw.shape != r2.acc_raw.shape: + raise RoiSetShapeMismatchError(f'Expecting two RoiSets of same shape: {r1.acc_raw.shape} != {r2.acc_raw.shape}') + + classes = r1.classification_columns + class_columns = [f'classify_by_{c}' for c in r1.classification_columns] + + df_overlaps = filter_df_overlap_seg(r1.get_df(), r2.get_df()) + df_merge = pd.merge( + r1.get_df()[class_columns], + df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']], + left_index=True, + right_index=True, + how='inner', + # suffixes=['r1', 'r2'] + ) + + print('hi') + # for cl in classes: + # se = pd.Series(dtype='Int64', index=r2.get_df().index) + # se.loc[:] = None + # se.loc[df_keep.] + # + # r1.loc[df] + +class Error(Exception): + pass + +class RoiSetShapeMismatchError(Error): + pass \ No newline at end of file diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 887ad06977bad7d0cc165b5729be084892aecfb0..f36d1dc8bcc4c56d85ef41d1ec3dffa3f1230a2d 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -11,6 +11,7 @@ from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.models import DummyInstanceSegmentationModel from model_server.base.process import smooth +from model_server.base.util import transfer_classification import model_server.conf.testing as conf data = conf.meta['image_files'] @@ -195,10 +196,23 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_transfer_classification(self): roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask) - smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=3.0)) + + # prepare alternative mask and compare + smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=1.5)) roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask) + dmask = (self.seg_mask.data / 255) + (smoothed_mask.data / 255) + self.assertTrue(np.all(np.unique(dmask) == [0, 1, 2])) + total_iou = (dmask == 2).sum() / ((dmask == 1).sum() + (dmask == 2).sum()) + self.assertGreater(total_iou, 0.6) + + # classify first RoiSet + roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) + + self.assertTrue('dummy_class' in roiset1.classification_columns) + transfer_classification(roiset1, roiset2) self.assertTrue(False) + def test_classify_by_with_derived_channel(self): class ModelWithDerivedInputs(DummyInstanceSegmentationModel): def infer(self, img, mask): @@ -648,8 +662,8 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): res = filter_df_overlap_bbox(df) self.assertEqual(len(res), 4) - self.assertTrue((res.loc[0, 'bbox_overlaps_with'] == [1]).all()) - self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0, 2]).all()) + self.assertTrue((res.loc[0, 'overlaps_with'] == [1]).all()) + self.assertTrue((res.loc[1, 'overlaps_with'] == [0, 2]).all()) self.assertTrue((res.bbox_intersec == 2).all()) return res @@ -670,7 +684,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): 'zi': [0], }) res = filter_df_overlap_bbox(df1, df2) - self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all()) + self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all()) self.assertEqual(len(res), 1) self.assertTrue((res.bbox_intersec == 2).all()) @@ -734,7 +748,7 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): ] }) res = filter_df_overlap_seg(df1, df2) - self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0]).all()) + self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all()) self.assertEqual(len(res), 1) self.assertTrue((res.bbox_intersec == 2).all()) self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all())