From d436bb0d702c66b08ff7711cb8623c9bb40ca14a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 16:44:15 +0100
Subject: [PATCH] Moved expanded and relative bounding box calcs in to DF
 creation

---
 model_server/extensions/chaeo/products.py |  6 +-
 model_server/extensions/chaeo/zmask.py    | 87 ++++++++++++++++++++---
 2 files changed, 80 insertions(+), 13 deletions(-)

diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index d68b4c36..c9acc1fd 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -63,7 +63,7 @@ def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
     for mi in roiset.zmask_meta:
         sl = mi['slice']
 
-        rbb = mi['relative_bounding_box']
+        rbb = mi['relative_bounding_box'] # TODO: call DF rbb fields
         x0 = rbb['x0']
         y0 = rbb['y0']
         x1 = rbb['x1']
@@ -111,7 +111,7 @@ def get_patches_from_zmask_meta(
     for mi in zmask_meta:
 
         sl = mi['slice']
-        rbb = mi['relative_bounding_box']
+        rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF
         idx = mi['df_index']
 
         x0 = rbb['x0']
@@ -267,7 +267,7 @@ def export_3d_patches_with_focus_metrics(
     for mi in zmask_meta:
         obj = mi['info']
         sl = mi['slice']
-        rbb = mi['relative_bounding_box']
+        rbb = mi['relative_bounding_box'] # TODO: use rel_ fields in DF
         idx = mi['df_index']
 
         patch = stack.data[sl]
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 37996a0d..7e852e6d 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -67,25 +67,66 @@ class RoiSet(object):
         self.acc_obj_ids = acc_obj_ids
         self.acc_raw = acc_raw
 
+
+
         self._df = self.filter_df(
-            self.make_df(self.acc_raw, self.acc_obj_ids),
+            self.make_df(self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by),
             params.filters,
         )
 
+        # self._df = self.make_slices(
+        #     self._df,
+        #     expand_box_by=params.expand_box_by,
+        #     shape=acc_raw.shape
+        # )
+
         # remaining zmask_meta write ops
-        self.zmask_meta, _, self.interm = build_zmask_from_object_mask(
-            acc_obj_ids,
-            acc_raw,
-            self.get_df(),
-            params=params,
-        )
+        # self.zmask_meta, _, self.interm = build_zmask_from_object_mask(
+        #     acc_obj_ids,
+        #     acc_raw,
+        #     self.get_df(),
+        #     params=params,
+        # )
+
+        # temporarily build zmask meta here
+        meta = []
+        for ob in self.get_df().itertuples(name='LabeledObject'):
+            sl = np.s_[ob.ebb_y0: ob.ebb_y1, ob.ebb_x0: ob.ebb_x1, :, ob.ebb_z0: ob.ebb_z1 + 1]  # TODO: on-the-fly in RoiSet, given DF
+
+            # compute contours
+            obmask = (acc_obj_ids == ob.label)  # TODO: on-the-fly
+            contour = find_contours(obmask)  # TODO: on-the-fly
+            mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1]
+
+            rbb = {  # TODO: just put in the DF
+                'y0': ob.rel_y0,
+                'y1': ob.rel_y1,
+                'x0': ob.rel_x0,
+                'x1': ob.rel_x1,
+            }
+
+            meta.append({
+                'df_index': ob.Index,
+                'info': ob,
+                'slice': sl,
+                'relative_bounding_box': rbb,  # TODO: put in DF
+                'contour': contour,  # TODO: delegate to getter
+                'mask': mask  # TODO: delegate to getter
+            })
+        self.zmask_meta = meta
+
+        # return intermediate image arrays  # TODO: make on-the-fly
+        self.interm = {
+            'label_map': acc_obj_ids,
+            'argmax': acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
+        }
 
         self.count = len(self.zmask_meta)
         self.object_id_labels = self.interm['label_map']
         self.object_class_map = None
 
     @staticmethod
-    def make_df(acc_raw, acc_obj_ids):
+    def make_df(acc_raw, acc_obj_ids, expand_box_by):
         # build dataframe of objects, assign z index to each object
         argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
         df = (
@@ -101,8 +142,34 @@ class RoiSet(object):
             )
         )
         df['zi'] = df['intensity_mean'].round().astype('int')
+
+        # compute expanded bounding boxes
+        h, w, c, nz = acc_raw.shape
+        ebxy, ebz = expand_box_by
+        df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0))
+        df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h))
+        df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0))
+        df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w))
+        df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0))
+        df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz))
+
+        # compute relative bounding boxes
+        df['rel_y0'] = df.y0 - df.ebb_y0
+        df['rel_y1'] = df.y1 - df.ebb_y1
+        df['rel_x0'] = df.x0 - df.ebb_x0
+        df['rel_x1'] = df.x1 - df.ebb_x1
+
+        assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0']))
+        assert np.all(df['rel_y1'] <= (df['ebb_x1'] - df['ebb_x0']))
+
         return df
 
+    # def get_slices(self):  # TODO: actually map to DF index as new column
+    #     sl = []
+    #     for ob in self.get_df().itertuples(name='LabeledObject'):
+    #         sl.append(np.s_[ob.ebb_y0: ob.ebb_y1, ob.ebb_x0: ob.ebb_x1, :, ob.ebb_z0: ob.ebb_z1 + 1])
+    #     return sl
+
     @staticmethod
     def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
         query_str = 'label > 0'  # always true
@@ -116,7 +183,7 @@ class RoiSet(object):
         # df.loc[df.query(query_str).index, 'keeper'] = True
         return df.loc[df.query(query_str).index, :]
 
-    def get_df(self) -> pd.DataFrame:
+    def get_df(self) -> pd.DataFrame:  # TODO: exclude columns that refer to objects
         return self._df
 
     def add_df_col(self, name, se: pd.Series) -> None:
@@ -344,7 +411,7 @@ def build_zmask_from_object_mask(
         meta.append({
             'df_index': ob.Index,
             'info': ob,
-            'slice': sl,
+            # 'slice': sl,
             'relative_bounding_box': rbb, # TODO: put in DF
             'contour': contour, # TODO: delegate to getter
             'mask': mask # TODO: delegate to getter
-- 
GitLab