From 881f1fde6edb6e2c4119c858c1232f5dab6bd494 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Oct 2024 12:06:08 +0100
Subject: [PATCH] Added and covered PatchStack dataframe with summary stats

---
 model_server/base/accessors.py | 16 +++++++++++-----
 tests/base/test_accessors.py   | 15 ++++++++++++++-
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 35f6575d..e48103eb 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -492,16 +492,22 @@ class PatchStack(InMemoryDataAccessor):
     def shape_dict(self):
         return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
 
-    @property
-    def df(self):
+    def get_object_df(self, mask) -> pd.DataFrame:
+        """
+        Given a mask patch stack of the same size, return a DataFrame summarizing the area and intensity of objects,
+        assuming that each patch in the patch stack represents a single object.
+        :param mask: mask patch stack with the same shape as this patch stack
+        """
+        if self.shape != mask.shape or not mask.is_mask():
+            raise DataShapeError(f'Patch stack object dataframe expects a mask of the same dimensions')
         df = pd.DataFrame([
             {
                 'label': i,
-                'area': (self.iat(i).data > 0).sum(),
-                'sum': self.iat(i).data.sum()
+                'area': (mask.iat(i).data > 0).sum(),
+                'intensity_sum': (self.iat(i).data * (mask.iat(i).data > 0)).sum()
             } for i in range(0, self.count)
         ])
-        df['intensity_mean'] = df['sum'] / df['area']
+        df['intensity_mean'] = df['intensity_sum'] / df['area']
         return df
 
 
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index 31d32d7d..05663dd7 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -303,8 +303,21 @@ class TestPatchStackAccessor(unittest.TestCase):
         self.assertEqual(acc.get_mono(channel=0).data_yxz.shape, (n, h, w, nz))
         self.assertEqual(acc.get_mono(channel=0, mip=True).data_yx.shape, (n, h, w))
 
+    def test_object_df(self):
+        w = 30
+        h = 20
+        n = 2
+        nz = 3
+        nc = 3
+        acc = PatchStack(_random_int(n, h, w, nc, nz))
+        mask_data= np.zeros((n, h, w, nc, nz), dtype='uint8')
+        mask_data[0, 0:5, 0:5, :, :] = 255
+        mask_data[1, 0:10, 0:10, :, :] = 255
+        mask = PatchStack(mask_data)
+        df = acc.get_object_df(mask)
         # intensity means are centered around half of full range
-        self.assertTrue(np.all(((acc.df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-3))
+        self.assertTrue(np.all(((df['intensity_mean'] / acc.dtype_max) - 0.5)**2 < 1e-3))
+        self.assertTrue(df['area'][1] / df['area'][0] == 4.0)
         return acc
 
     def test_get_one_channel(self):
-- 
GitLab