From ce711cba1966178fa4f9ed19cf34db1ebb07248c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 21:53:38 +0100
Subject: [PATCH] slice and mask array objects now in RoiSet's dataframe

---
 model_server/extensions/chaeo/products.py     |  3 +-
 .../extensions/chaeo/tests/test_zstack.py     |  8 +--
 model_server/extensions/chaeo/zmask.py        | 59 +++++++++----------
 3 files changed, 32 insertions(+), 38 deletions(-)

diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index 82699fb2..84535d73 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -61,9 +61,8 @@ def write_patch_to_file(where, fname, yxcz):
 def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
     patches = []
     for ob in roiset.get_df().itertuples('Roi'):
-        sp_sl = roiset.get_rel_slice_at(ob.Index)
         patch = np.zeros((ob.ebb_h, ob.ebb_w, 1, 1), dtype='uint8')
-        patch[sp_sl][:, :, 0, 0] = roiset.get_mask_at(ob.Index) * 255
+        patch[ob.relative_slice][:, :, 0, 0] = ob.mask * 255
 
         if pad_to:
             patch = pad(patch, pad_to)
diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index cc9a4297..6fca51b8 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -96,20 +96,18 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
 
     def test_zmask_slices_are_valid(self):
         roiset = self.test_zmask_makes_correct_boxes()
-        slices = [roiset.get_slice_at(i) for i in roiset.get_df().index]
-        for s in slices:
+        for s in roiset.get_slices():
             ebb = roiset.acc_raw.data[s]
             self.assertEqual(len(ebb.shape), 4)
             self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
 
     def test_zmask_rel_slices_are_valid(self):
         roiset = self.test_zmask_makes_correct_boxes()
-        slices = [roiset.get_slice_at(i) for i in roiset.get_df().index]
-        rel_slices = [roiset.get_rel_slice_at(i) for i in roiset.get_df().index]
-        for i, s in enumerate(slices):
+        for i, s in enumerate(roiset.get_slices()):
             ebb = roiset.acc_raw.data[s]
             self.assertEqual(len(ebb.shape), 4)
             self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
+            rel_slices = roiset.get_df()['relative_slice']
             rbb = ebb[rel_slices[i]]
             self.assertEqual(len(rbb.shape), 4)
             self.assertTrue(np.all([si >= 1 for si in rbb.shape]))
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 2cef8230..b06c276f 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -36,7 +36,9 @@ class RoiSet(object):
         self.acc_raw = acc_raw
 
         self._df = self.filter_df(
-            self.make_df(self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by),
+            self.make_df(
+                self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by
+            ),
             params.filters,
         )
 
@@ -78,7 +80,14 @@ class RoiSet(object):
         self.object_class_map = None
 
     @staticmethod
-    def make_df(acc_raw, acc_obj_ids, expand_box_by):
+    def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
+        """
+        Build dataframe associate object IDs with summary stats
+        :param acc_raw: accessor to raw image data
+        :param acc_obj_ids: accessor to map of object IDs
+        :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
+        :return: pd.DataFrame
+        """
         # build dataframe of objects, assign z index to each object
         argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
         df = (
@@ -117,36 +126,22 @@ class RoiSet(object):
         assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0']))
         assert np.all(df['rel_y1'] <= (df['ebb_x1'] - df['ebb_x0']))
 
+        df['slice'] = df.apply(
+            lambda r:
+            np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1],
+            axis=1
+        )
+        df['relative_slice'] = df.apply(
+            lambda r:
+            np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :],
+            axis=1
+        )
+        df['mask'] = df.apply(
+            lambda r: (acc_obj_ids == r.label)[r.y0: r.y1, r.x0: r.x1],
+            axis=1
+        )
         return df
 
-    def get_slice_at(self, idx) -> tuple:
-        """
-        Return slice object in np.s_ format that defines expanded bounding box of object
-        :param idx: object index (Index in DataFrame, does not necessarily start at zero)
-        :return: slice object
-        """
-        ob = self.get_df().loc[idx, :].astype('int64')
-        return np.s_[ob.ebb_y0: ob.ebb_y1, ob.ebb_x0: ob.ebb_x1, :, ob.ebb_z0: ob.ebb_z1 + 1]
-
-    def get_rel_slice_at(self, idx) -> tuple:
-        """
-        Return slice object in np.s_ format that defines bounding box of an object within its expanded bounding box
-        :param idx: object index (Index in DataFrame, does not necessarily start at zero)
-        :return: slice object
-        """
-        ob = self.get_df().loc[idx, :].astype('int64')
-        return np.s_[ob.rel_y0: ob.rel_y1, ob.rel_x0: ob.rel_x1, :, :]
-
-
-    def get_mask_at(self, idx) -> np.ndarray:
-        """
-        Return 2D array describing object mask that fills (unexpanded) bounding box at index idx
-        :param idx: object index (Index in DataFrame, does not necessarily start at zero)
-        :return: np.ndarray boolean mask
-        """
-        ob = self.get_df().loc[idx, :].astype('int64')
-        obmask = (self.acc_obj_ids == ob.label)
-        return obmask[ob.y0: ob.y1, ob.x0: ob.x1]
 
     @staticmethod
     def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
@@ -158,12 +153,14 @@ class RoiSet(object):
                 vmax = val['max']
                 assert vmin >= 0
                 query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
-        # df.loc[df.query(query_str).index, 'keeper'] = True
         return df.loc[df.query(query_str).index, :]
 
     def get_df(self) -> pd.DataFrame:  # TODO: exclude columns that refer to objects
         return self._df
 
+    def get_slices(self) -> pd.Series:
+        return self.get_df()['slice']
+
     def add_df_col(self, name, se: pd.Series) -> None:
         self._df[name] = se
 
-- 
GitLab