From 8ecc14be9d9c7008e02635fd403c299aab07b29c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 16 Jul 2024 19:06:21 +0200
Subject: [PATCH] Generally removing necessity of instantiating with a map of
 object IDs

---
 model_server/base/roiset.py | 187 ++++++++++++++++++++++--------------
 tests/base/test_roiset.py   |   6 +-
 2 files changed, 118 insertions(+), 75 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 36f9e066..6afa4b1a 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -108,6 +108,19 @@ def _focus_metrics():
         'moment': lambda x: moment(x.flatten(), moment=2),
     }
 
+
+def _filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
+    query_str = 'label > 0'  # always true
+    if filters is not None:  # parse filters
+        for k, val in filters.dict(exclude_unset=True).items():
+            assert k in ('area')
+            vmin = val['min']
+            vmax = val['max']
+            assert vmin >= 0
+            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
+    return df.loc[df.query(query_str).index, :]
+
+
 # TODO: get overlapping bounding boxes
 def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
 
@@ -128,9 +141,50 @@ def _filter_overlap_bbox(df: pd.DataFrame) -> pd.DataFrame:
     sdf['overlaps_with'] = second
     return sdf
 
+def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
+    """
+    Build dataframe associate object IDs with summary stats
+    :param acc_raw: accessor to raw image data
+    :param acc_obj_ids: accessor to map of object IDs
+    :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
+    # :param deproject: assign object's z-position based on argmax of raw data if True
+    :return: pd.DataFrame
+    """
+    # build dataframe of objects, assign z index to each object
+
+    if acc_obj_ids.nz == 1:  # deproject objects' z-coordinates from argmax of raw image
+        df = pd.DataFrame(regionprops_table(
+            acc_obj_ids.data[:, :, 0, 0],
+            intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
+            properties=('label', 'area', 'intensity_mean', 'bbox')
+        )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'})
+        df['zi'] = df['intensity_mean'].round().astype('int')
+
+    else:  # objects' z-coordinates come from arg of max count in object identities map
+        df = pd.DataFrame(regionprops_table(
+            acc_obj_ids.data[:, :, 0, :],
+            properties=('label', 'area', 'bbox')
+        )).rename(columns={
+            'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
+        })
+        df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax())
+
+    df = _df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
+
+    # TODO: make this contingent on whether seg is included
+    df['binary_mask'] = df.apply(
+        lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0],
+        axis=1,
+        result_type='reduce',
+    )
+    return df
+
+
+def _df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame:
+    h = sd['Y']
+    w = sd['X']
+    nz = sd['Z']
 
-def _df_insert_slices(df: pd.DataFrame, shape: tuple, expand_box_by) -> pd.DataFrame:
-    h, w, c, nz = shape
     df['h'] = df['y1'] - df['y0']
     df['w'] = df['x1'] - df['x0']
     ebxy, ebz = expand_box_by
@@ -185,6 +239,18 @@ def _safe_add(a, g, b):
         np.iinfo(a.dtype).max
     ).astype(a.dtype)
 
+def _make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor:
+    id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16')
+    if 'binary_mask' not in df.columns:
+        raise MissingSegmentationError('RoiSet dataframe does not contain segmentation')
+
+    def _label_obj(r):
+        sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1]
+        id_mask[sl] = id_mask[sl] + r.label * r.binary_mask
+
+    df.apply(_label_obj, axis=1)
+    return InMemoryDataAccessor(id_mask)
+
 
 class RoiSet(object):
 
@@ -192,7 +258,8 @@ class RoiSet(object):
     def __init__(
             self,
             acc_raw: GenericImageDataAccessor,
-            acc_obj_ids: GenericImageDataAccessor,
+            # acc_obj_ids: GenericImageDataAccessor,
+            df: pd.DataFrame,
             params: RoiSetMetaParams = RoiSetMetaParams(),
     ):
         """
@@ -203,19 +270,20 @@ class RoiSet(object):
             labels its membership in a connected object
         :param params: optional arguments that influence the definition and representation of ROIs
         """
-        assert acc_obj_ids.chroma == 1
-        self.acc_obj_ids = acc_obj_ids
+        # assert acc_obj_ids.chroma == 1
+        # self.acc_obj_ids = acc_obj_ids
         self.acc_raw = acc_raw
         self.accs_derived = []
         self.params = params
 
-        self._df = self.filter_df(
-            self.make_df_from_object_ids(
-                self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by
-            ),
-            params.filters,
-        )
+        # self._df = self.filter_df(
+        #     self.make_df_from_object_ids(
+        #         self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by
+        #     ),
+        #     params.filters,
+        # )
 
+        self._df = df
         self.count = len(self._df)
         self.object_class_maps = {}  # classification results
 
@@ -223,6 +291,24 @@ class RoiSet(object):
         """Expose ROI meta information via the Pandas.DataFrame API"""
         return self._df.itertuples(name='Roi')
 
+    @staticmethod
+    def from_object_ids(
+            acc_raw: GenericImageDataAccessor,
+            acc_obj_ids: GenericImageDataAccessor,
+            params: RoiSetMetaParams = RoiSetMetaParams(),
+    ):
+        assert acc_obj_ids.chroma == 1
+
+        df = _filter_df(
+            _make_df_from_object_ids(
+                acc_raw, acc_obj_ids, expand_box_by=params.expand_box_by
+            ),
+            params.filters,
+        )
+
+        return RoiSet(acc_raw, df, params)
+
+
     # TODO: add or overload for object detection case
     @staticmethod
     def from_binary_mask(
@@ -241,50 +327,16 @@ class RoiSet(object):
         :param params: optional arguments that influence the definition and representation of ROIs
         :return: object identities map
         """
-        return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
+        return __class__.from_object_ids(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
 
 
     # TODO: generate overlapping RoiSet from multiple masks
     # call e.g. static adder
 
+
     @staticmethod
-    def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
-        """
-        Build dataframe associate object IDs with summary stats
-        :param acc_raw: accessor to raw image data
-        :param acc_obj_ids: accessor to map of object IDs
-        :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
-        # :param deproject: assign object's z-position based on argmax of raw data if True
-        :return: pd.DataFrame
-        """
-        # build dataframe of objects, assign z index to each object
-
-        if acc_obj_ids.nz == 1:  # deproject objects' z-coordinates from argmax of raw image
-            df = pd.DataFrame(regionprops_table(
-                acc_obj_ids.data[:, :, 0, 0],
-                intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
-                properties=('label', 'area', 'intensity_mean', 'bbox')
-            )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'})
-            df['zi'] = df['intensity_mean'].round().astype('int')
-
-        else:  # objects' z-coordinates come from arg of max count in object identities map
-            df = pd.DataFrame(regionprops_table(
-                acc_obj_ids.data[:, :, 0, :],
-                properties=('label', 'area', 'bbox')
-            )).rename(columns={
-                'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
-            })
-            df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == x).sum(axis=(0, 1, 2)).argmax())
-
-        df = _df_insert_slices(df, acc_raw.shape, expand_box_by)
-
-        # TODO: make this contingent on whether seg is included
-        df['binary_mask'] = df.apply(
-            lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0],
-            axis=1,
-            result_type='reduce',
-        )
-        return df
+    def make_df_from_patches():
+        pass
 
 
     # TODO: get overlapping segments
@@ -298,21 +350,6 @@ class RoiSet(object):
 
         dfbb['iou'] = dfbb.apply()
 
-
-    # TODO: test if overlaps exist
-
-    @staticmethod
-    def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
-        query_str = 'label > 0'  # always true
-        if filters is not None:  # parse filters
-            for k, val in filters.dict(exclude_unset=True).items():
-                assert k in ('area')
-                vmin = val['min']
-                vmax = val['max']
-                assert vmin >= 0
-                query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
-        return df.loc[df.query(query_str).index, :]
-
     def get_df(self) -> pd.DataFrame:
         return self._df
 
@@ -731,26 +768,29 @@ class RoiSet(object):
 
         return {}
 
+    @property
+    def acc_obj_ids(self):
+        return _make_object_ids_from_df(self._df, self.acc_raw.shape_dict)
+
     # TODO: add docstring
     # TODO: make this work with obj det dataset
     @staticmethod
     def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''):
         df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']]
 
-        id_mask = np.zeros((*acc_raw.hw, 1, acc_raw.nz), dtype='uint16')
-        def _label_obj(r):
-            sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1]
+        def _read_binary_mask(r):
             ext = 'png'
             fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}'
             try:
                 ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname)
-                bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max
-                id_mask[sl] = id_mask[sl] + r.label * bool_mask
+                assert ma_acc.chroma == 1 and ma_acc.nz == 1
+                mask_data = ma_acc.data / np.iinfo(ma_acc.data.dtype).max
+                return mask_data
             except Exception as e:
                 raise DeserializeRoiSet(e)
-
-        df.apply(_label_obj, axis=1)
-        return RoiSet(acc_raw, InMemoryDataAccessor(id_mask))
+        df['binary_mask'] = df.apply(_read_binary_mask, axis=1)
+        id_mask = _make_object_ids_from_df(df, acc_raw.shape_dict)
+        return RoiSet.from_object_ids(acc_raw, id_mask)
 
 
 def project_stack_from_focal_points(
@@ -808,4 +848,7 @@ class DeserializeRoiSet(Error):
     pass
 
 class DerivedChannelError(Error):
+    pass
+
+class MissingSegmentationError(Error):
     pass
\ No newline at end of file
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index dcf849bd..8e577f50 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -137,7 +137,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def test_create_roiset_with_no_objects(self):
         zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype))
-        roiset = RoiSet(self.stack_ch_pa, zero_obmap)
+        roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap)
         self.assertEqual(roiset.count, 0)
 
     def test_slices_are_valid(self):
@@ -592,7 +592,7 @@ class TestRoiSetSerialization(unittest.TestCase):
         id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False)
         self.assertEqual(self.stack_ch_pa.shape, id_map.shape)
 
-        roiset = RoiSet(
+        roiset = RoiSet.from_object_ids(
             self.stack_ch_pa,
             id_map,
             params=RoiSetMetaParams(mask_type='contours')
@@ -605,7 +605,7 @@ class TestRoiSetSerialization(unittest.TestCase):
         self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3])
         self.assertEqual(id_map.nz, 1)
 
-        roiset = RoiSet(
+        roiset = RoiSet.from_object_ids(
             self.stack_ch_pa,
             id_map,
             params=RoiSetMetaParams(mask_type='contours')
-- 
GitLab