From 1819a0f62ffa706a54bdaf9359447369e3da249d Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 2 Nov 2024 09:08:46 +0100 Subject: [PATCH] Removed aggregation method, instead can filter by other classification results --- model_server/base/roiset.py | 29 +++++++---------------------- tests/base/test_roiset.py | 20 ++++++-------------- 2 files changed, 13 insertions(+), 36 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 925b3b7b..f7cc75f2 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -618,25 +618,6 @@ class RoiSet(object): se[roi.Index] = oc self.set_classification(f'classify_by_{name}', se) - def aggregate_classifications(self, query: str, name: str = 'aggregate'): - """ - Run query on DataFrame and put the results in a new boolean column - """ - cname = 'classify_by_' + name - if cname in self._df.columns: - raise DataFrameQueryError(f'Name {cname} is already used in RoiSet dataframe') - - if self.count == 0: - self._df[cname] = None - return True - - try: - self.set_classification( - cname, - self._df.eval(query) - ) - except Exception as e: - raise DataFrameQueryError(e) def get_instance_classification(self, roiset_from: Self, iou_min: float = 0.5) -> pd.DataFrame: """ @@ -671,10 +652,11 @@ class RoiSet(object): return df_overlaps - def get_object_class_map(self, name: str) -> InMemoryDataAccessor: + def get_object_class_map(self, name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor: """ For a given classification result, return a map where object IDs are replaced by each object's class :param name: name of the classification result, same as passed to RoiSet.classify_by() + :param filter_by: only include ROIs if the intersection of all specified classifications is True :return: accessor of object class map """ colname = ('classify_by_' + name) @@ -684,8 +666,11 @@ class RoiSet(object): def _label_object_class(roi): om[self.acc_obj_ids.data == roi.label] = roi[colname] - self._df.apply(_label_object_class, axis=1) - + if filter_by is None: + self._df.apply(_label_object_class, axis=1) + else: + pd_fil = self._df[[f'classify_by_{fb}' for fb in filter_by]] + self._df.loc[pd_fil.all(axis=1), :].apply(_label_object_class, axis=1) return InMemoryDataAccessor(om) def get_serializable_dataframe(self) -> pd.DataFrame: diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 5f8fe52b..ee73aa4b 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -857,18 +857,10 @@ class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.Tes def test_aggregate_classification_results(self): roiset = self.test_roiset_with_instance_segmentation() - roiset.aggregate_classifications( - query='classify_by_permissive_model == 1 & classify_by_avg_intensity == 1', - name='aggregation' - ) - self.assertIn('aggregation', roiset.classification_columns) - self.assertTrue(np.all(roiset.get_object_class_map('aggregation').unique()[0] == [0, 1])) - self.assertEqual( - roiset.get_object_class_map('aggregation').data.sum(), - roiset.get_object_class_map('avg_intensity').data.sum(), - ) - self.assertGreater( - roiset.get_object_class_map('permissive_model').data.sum(), - roiset.get_object_class_map('aggregation').data.sum(), - ) + om_mod = roiset.get_object_class_map('permissive_model') + om_tr = roiset.get_object_class_map('avg_intensity') + om_fil = roiset.get_object_class_map('permissive_model', filter_by=['avg_intensity']) + self.assertTrue(np.all(om_fil.unique()[0] == [0, 1])) + self.assertEqual(om_fil.data.sum(), om_tr.data.sum()) + self.assertGreater(om_mod.data.sum(), om_fil.data.sum()) -- GitLab