From 3e8ca3db67c43fff1e6f69cc90bccb60a9734ee5 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 15:48:14 +0100
Subject: [PATCH] Moved dataframe logic outside of constructor; losing
 alignment of patch stacks df

---
 .../extensions/chaeo/tests/test_zstack.py     |  10 +-
 model_server/extensions/chaeo/zmask.py        | 120 ++++++++++++------
 2 files changed, 88 insertions(+), 42 deletions(-)

diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index a925a503..aa4deb86 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -123,16 +123,16 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         id_map = get_label_ids(self.seg_mask)
 
         roiset = RoiSet(id_map, self.stack_ch_pa, params=RoiSetMetaParams(mask_type='boxes'))
-        df = roiset.df
+        df = roiset.get_df(filters=None)
 
         from model_server.extensions.chaeo.zmask import project_stack_from_focal_points
 
-        dff = df[df['keeper']]
+        # dff = df[df['keeper']]
 
         img = project_stack_from_focal_points(
-            dff['centroid-0'].to_numpy(),
-            dff['centroid-1'].to_numpy(),
-            dff['zi'].to_numpy(),
+            df['centroid-0'].to_numpy(),
+            df['centroid-1'].to_numpy(),
+            df['zi'].to_numpy(),
             self.stack,
             degree=4,
         )
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 587c801b..9034e1ce 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -15,7 +15,7 @@ from model_server.base.process import pad, rescale, resample_to_8bit
 
 from model_server.extensions.chaeo.annotators import draw_boxes_on_3d_image
 from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack
-from extensions.chaeo.params import RoiSetMetaParams, RoiSetExportParams
+from extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams
 from model_server.extensions.chaeo.accessors import MonoPatchStack
 from model_server.extensions.chaeo.process import mask_largest_object
 from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, get_patch_masks, export_patch_masks_from_zstack
@@ -32,17 +32,57 @@ class RoiSet(object):
             acc_raw: GenericImageDataAccessor,
             params: RoiSetMetaParams = RoiSetMetaParams(),
     ):
-        # parse filters
-        filters = params.filters
-        query_str = 'label > 0'  # always true
-        if filters is not None:
-            for k, val in filters.dict(exclude_unset=True).items():
-                assert k in ('area', 'solidity')
-                vmin = val['min']
-                vmax = val['max']
-                assert vmin >= 0
-                query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
+        # # parse filters
+        # filters = params.filters
+        # query_str = 'label > 0'  # always true
+        # if filters is not None:
+        #     for k, val in filters.dict(exclude_unset=True).items():
+        #         assert k in ('area', 'solidity')
+        #         vmin = val['min']
+        #         vmax = val['max']
+        #         assert vmin >= 0
+        #         query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
+        #
+        # # build dataframe of objects, assign z index to each object
+        # argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
+        # df = (
+        #     pd.DataFrame(
+        #         regionprops_table(
+        #             acc_obj_ids,
+        #             intensity_image=argmax,
+        #             properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
+        #         )
+        #     )
+        #     .rename(
+        #         columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1',}
+        #     )
+        # )
+        # df['zi'] = df['intensity_mean'].round().astype('int')
+        # df['keeper'] = False
+        # df.loc[df.query(query_str).index, 'keeper'] = True
+        # self.df = df
+
+        # remaining zmask_meta write ops
+
+        self.acc_obj_ids = acc_obj_ids
+        self.acc_raw = acc_raw
+
+        self._df = self.make_df(self.acc_raw, self.acc_obj_ids)
 
+        # remaining zmask_meta write ops
+        self.zmask_meta, _, self.interm = build_zmask_from_object_mask(
+            acc_obj_ids,
+            acc_raw,
+            self.get_df(filters=params.filters),
+            params=params,
+        )
+
+        self.count = len(self.zmask_meta)
+        self.object_id_labels = self.interm['label_map']
+        self.object_class_map = None
+
+    @staticmethod
+    def make_df(acc_raw, acc_obj_ids):
         # build dataframe of objects, assign z index to each object
         argmax = acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
         df = (
@@ -54,27 +94,26 @@ class RoiSet(object):
                 )
             )
             .rename(
-                columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1',}
+                columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1', }
             )
         )
         df['zi'] = df['intensity_mean'].round().astype('int')
-        df['keeper'] = False
-        df.loc[df.query(query_str).index, 'keeper'] = True
-        self.df = df
+        return df
 
-        # remaining zmask_meta write ops
-        self.zmask_meta, _, self.interm = build_zmask_from_object_mask(
-            acc_obj_ids,
-            acc_raw,
-            df,
-            params=params,
-        )
-        self.acc_obj_ids = acc_obj_ids
-        self.acc_raw = acc_raw
-        self.count = len(self.zmask_meta)
-        self.object_id_labels = self.interm['label_map']
-        self.object_class_map = None
+    def get_df(self, filters: RoiFilter = None) -> pd.DataFrame:
+        query_str = 'label > 0'  # always true
+        if filters is not None:  # parse filters
+            for k, val in filters.dict(exclude_unset=True).items():
+                assert k in ('area', 'solidity')
+                vmin = val['min']
+                vmax = val['max']
+                assert vmin >= 0
+                query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
+        # df.loc[df.query(query_str).index, 'keeper'] = True
+        return self._df.loc[self._df.query(query_str).index, :]
 
+    def add_df_col(self, name, se: pd.Series) -> None:
+        self._df[name] = se
 
     def get_multichannel_projection(self):  # TODO: document and test
         dff = self.df[self.df['keeper']]
@@ -121,7 +160,7 @@ class RoiSet(object):
     def get_slices(self):
         return [zm.slice for zm in self.zmask_meta]
 
-    def get_zmask(self, mask_type='boxes'):
+    def get_zmask(self, mask_type='boxes', filters: RoiFilter = None):
         """
         Return a mask of same dimensionality as raw data
 
@@ -135,7 +174,8 @@ class RoiSet(object):
 
         # make an object map where label is replaced by focus position in stack and background is -1
         lut = np.zeros(lamap.max() + 1) - 1
-        lut[self.df.label] = self.df.zi
+        df = self.get_df(filters=filters)
+        lut[df.label] = df.zi
 
         if mask_type == 'contours':
             zi_map = (lut[lamap] + 1.0).astype('int')
@@ -157,24 +197,30 @@ class RoiSet(object):
 
         return zi_st
 
-    def classify_by(self, channel, object_classification_model: InstanceSegmentationModel):
+    def classify_by(self, channel, object_classification_model: InstanceSegmentationModel, filters: RoiFilter = None):
+        # adds a column to self._df
         # do this on a patch basis, i.e. only one object per frame
         obmap_patches = object_classification_model.label_instance_class(
-            self.get_raw_patches(channel),
+            self.get_raw_patches(channel),  # TODO: enforce df index
             self.get_patch_masks()
         )
 
         lamap = self.object_id_labels
         om = np.zeros(lamap.shape, dtype=lamap.dtype)
-        self.df['instance_class'] = np.nan
+        # self.df['instance_class'] = np.nan
+
+        df = self.get_df(filters=filters)
+        idx = df.index
+        se = pd.Series(data=np.nan, index=idx)
 
         # assign labels to object map:
-        for ii in range(0, self.count):
-            object_id = self.zmask_meta[ii]['info'].label
-            result_patch = mask_largest_object(obmap_patches.iat(ii))
+        for i in idx:
+            # object_id = self.zmask_meta[i]['info'].label
+            object_id = df.loc[i, 'label']
+            result_patch = mask_largest_object(obmap_patches.iat(i))
             object_class = np.unique(result_patch)[1]
             om[self.object_id_labels == object_id] = object_class
-            self.df[object_id, 'instance_class'] = object_class
+            se.loc[i] = object_class
 
         self.object_class_map = InMemoryDataAccessor(om)
 
@@ -262,7 +308,7 @@ def build_zmask_from_object_mask(
     h, w, c, nz = zstack.shape
 
     meta = []
-    for ob in df[df['keeper']].itertuples(name='LabeledObject'):
+    for ob in df.itertuples(name='LabeledObject'):
         y0 = max(ob.y0 - ebxy, 0)
         y1 = min(ob.y1 + ebxy, h)
         x0 = max(ob.x0 - ebxy, 0)
-- 
GitLab