From 53ef598ac1fab09035b46631bf60b57439f387a9 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 13 Sep 2024 10:08:06 +0200 Subject: [PATCH] Object classification transfer is functional --- model_server/base/roiset.py | 45 ++++++++++++++++++++++++++++++++----- model_server/base/util.py | 38 +------------------------------ tests/base/test_roiset.py | 10 ++++++--- 3 files changed, 48 insertions(+), 45 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index f4829112..7b1ef9af 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -561,9 +561,39 @@ class RoiSet(object): ) )[-1] se[roi.Index] = oc - self.set_classification(name, se) + self.set_classification(f'classify_by_{name}', se) + # TODO: typecheck RoiSet not recognized + def get_instance_classification(self, roiset_from, iou_min: float = 0.5): + """ + Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity + :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet + :param iou_min: threshold IOU below which a label is not transferred + """ + if self.acc_raw.shape != roiset_from.acc_raw.shape: + raise ShapeMismatchError( + f'Expecting two RoiSets of same shape: {self.acc_raw.shape} != {roiset_from.acc_raw.shape}') + + columns = [f'classify_by_{c}' for c in roiset_from.classification_columns] + + if len(columns) == 0: + raise MissingInstanceLabelsError('Expecting at least on instance classification channel but none found') + + df_overlaps = filter_df_overlap_seg( + roiset_from.get_df(), + self.get_df() + ) + df_merge = pd.merge( + roiset_from.get_df()[columns], + df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']], + left_index=True, + right_index=True, + how='inner', + ).set_index('overlaps_with') + for col in columns: + self.set_classification(col, df_merge[col]) + def get_object_class_map(self, name: str) -> InMemoryDataAccessor: """ For a given classification result, return a map where object IDs are replaced by each object's class @@ -781,14 +811,13 @@ class RoiSet(object): pr = 'classify_by_' return [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] - def set_classification(self, cname: str, se: pd.Series): + def set_classification(self, colname: str, se: pd.Series): """ Set instance classification result as a column addition on dataframe - :param cname: name of classification result + :param colname: name of classification result :param se: series containing class information """ - col = f'classify_by_{cname}' - self._df[col] = se + self._df[colname] = se def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: """ @@ -952,4 +981,10 @@ class MissingSegmentationError(Error): pass class PatchMaskShapeError(Error): + pass + +class ShapeMismatchError(Error): + pass + +class MissingInstanceLabelsError(Error): pass \ No newline at end of file diff --git a/model_server/base/util.py b/model_server/base/util.py index 58375761..81736d48 100644 --- a/model_server/base/util.py +++ b/model_server/base/util.py @@ -164,40 +164,4 @@ def loop_workflow( ) if len(failures) > 0: - pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') - -def transfer_classification( - r1: RoiSet, - r2: RoiSet, - iou_min: float = 0.5, -): - - if r1.acc_raw.shape != r2.acc_raw.shape: - raise RoiSetShapeMismatchError(f'Expecting two RoiSets of same shape: {r1.acc_raw.shape} != {r2.acc_raw.shape}') - - classes = r1.classification_columns - class_columns = [f'classify_by_{c}' for c in r1.classification_columns] - - df_overlaps = filter_df_overlap_seg(r1.get_df(), r2.get_df()) - df_merge = pd.merge( - r1.get_df()[class_columns], - df_overlaps.loc[df_overlaps.seg_iou > iou_min, ['overlaps_with']], - left_index=True, - right_index=True, - how='inner', - # suffixes=['r1', 'r2'] - ) - - print('hi') - # for cl in classes: - # se = pd.Series(dtype='Int64', index=r2.get_df().index) - # se.loc[:] = None - # se.loc[df_keep.] - # - # r1.loc[df] - -class Error(Exception): - pass - -class RoiSetShapeMismatchError(Error): - pass \ No newline at end of file + pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') \ No newline at end of file diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index f36d1dc8..53c1f042 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -11,7 +11,6 @@ from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.models import DummyInstanceSegmentationModel from model_server.base.process import smooth -from model_server.base.util import transfer_classification import model_server.conf.testing as conf data = conf.meta['image_files'] @@ -209,8 +208,13 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) self.assertTrue('dummy_class' in roiset1.classification_columns) - transfer_classification(roiset1, roiset2) - self.assertTrue(False) + self.assertFalse('dummy_class' in roiset2.classification_columns) + roiset2.get_instance_classification(roiset1) + self.assertTrue('dummy_class' in roiset2.classification_columns) + self.assertLess( + roiset2.get_df().classify_by_dummy_class.count(), + roiset1.get_df().classify_by_dummy_class.count(), + ) def test_classify_by_with_derived_channel(self): -- GitLab