From 1819a0f62ffa706a54bdaf9359447369e3da249d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 2 Nov 2024 09:08:46 +0100
Subject: [PATCH] Removed aggregation method; object class maps can instead be
 filtered by other classification results

---
 model_server/base/roiset.py | 29 +++++++----------------------
 tests/base/test_roiset.py   | 20 ++++++--------------
 2 files changed, 13 insertions(+), 36 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 925b3b7b..f7cc75f2 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -618,25 +618,6 @@ class RoiSet(object):
             se[roi.Index] = oc
         self.set_classification(f'classify_by_{name}', se)
 
-    def aggregate_classifications(self, query: str, name: str = 'aggregate'):
-        """
-        Run query on DataFrame and put the results in a new boolean column
-        """
-        cname = 'classify_by_' + name
-        if cname in self._df.columns:
-            raise DataFrameQueryError(f'Name {cname} is already used in RoiSet dataframe')
-
-        if self.count == 0:
-            self._df[cname] = None
-            return True
-
-        try:
-            self.set_classification(
-                cname,
-                self._df.eval(query)
-            )
-        except Exception as e:
-            raise DataFrameQueryError(e)
 
     def get_instance_classification(self, roiset_from: Self, iou_min: float = 0.5) -> pd.DataFrame:
         """
@@ -671,10 +652,11 @@ class RoiSet(object):
 
         return df_overlaps
 
-    def get_object_class_map(self, name: str) -> InMemoryDataAccessor:
+    def get_object_class_map(self, name: str, filter_by: Union[List, None] = None) -> InMemoryDataAccessor:
         """
         For a given classification result, return a map where object IDs are replaced by each object's class
         :param name: name of the classification result, same as passed to RoiSet.classify_by()
+        :param filter_by: names of other classification results; only ROIs where all of them are truthy are mapped
         :return: accessor of object class map
         """
         colname = ('classify_by_' + name)
@@ -684,8 +666,11 @@ class RoiSet(object):
 
         def _label_object_class(roi):
             om[self.acc_obj_ids.data == roi.label] = roi[colname]
-        self._df.apply(_label_object_class, axis=1)
-
+        if filter_by is None:
+            self._df.apply(_label_object_class, axis=1)
+        else:
+            pd_fil = self._df[[f'classify_by_{fb}' for fb in filter_by]]
+            self._df.loc[pd_fil.all(axis=1), :].apply(_label_object_class, axis=1)
         return InMemoryDataAccessor(om)
 
     def get_serializable_dataframe(self) -> pd.DataFrame:
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 5f8fe52b..ee73aa4b 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -857,18 +857,10 @@ class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.Tes
 
     def test_aggregate_classification_results(self):
         roiset = self.test_roiset_with_instance_segmentation()
-        roiset.aggregate_classifications(
-            query='classify_by_permissive_model == 1 & classify_by_avg_intensity == 1',
-            name='aggregation'
-        )
-        self.assertIn('aggregation', roiset.classification_columns)
-        self.assertTrue(np.all(roiset.get_object_class_map('aggregation').unique()[0] == [0, 1]))
-        self.assertEqual(
-            roiset.get_object_class_map('aggregation').data.sum(),
-            roiset.get_object_class_map('avg_intensity').data.sum(),
-        )
-        self.assertGreater(
-            roiset.get_object_class_map('permissive_model').data.sum(),
-            roiset.get_object_class_map('aggregation').data.sum(),
-        )
+        om_mod = roiset.get_object_class_map('permissive_model')
+        om_tr = roiset.get_object_class_map('avg_intensity')
+        om_fil = roiset.get_object_class_map('permissive_model', filter_by=['avg_intensity'])
+        self.assertTrue(np.all(om_fil.unique()[0] == [0, 1]))
+        self.assertEqual(om_fil.data.sum(), om_tr.data.sum())
+        self.assertGreater(om_mod.data.sum(), om_fil.data.sum())
 
-- 
GitLab