From 3bf6161ca420838381db33e35d8c186c90a9a0b5 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 16 Jul 2024 14:48:34 +0200
Subject: [PATCH] Removed image-flattening from focal points

---
 model_server/base/roiset.py | 26 +++++--------
 tests/base/test_roiset.py   | 78 ++++++++++++++++++++++++-------------
 2 files changed, 60 insertions(+), 44 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 360d86f9..23748e60 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -218,14 +218,14 @@ class RoiSet(object):
             df = pd.DataFrame(regionprops_table(
                 acc_obj_ids.data[:, :, 0, 0],
                 intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
-                properties=('label', 'area', 'intensity_mean', 'bbox', 'centroid')
+                properties=('label', 'area', 'intensity_mean', 'bbox')
             )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'})
             df['zi'] = df['intensity_mean'].round().astype('int')
 
         else:  # objects' z-coordinates come from arg of max count in object identities map
             df = pd.DataFrame(regionprops_table(
                 acc_obj_ids.data[:, :, 0, :],
-                properties=('label', 'area', 'bbox', 'centroid')
+                properties=('label', 'area', 'bbox')
             )).rename(columns={
                 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
             })
@@ -285,7 +285,14 @@ class RoiSet(object):
 
     # TODO: get overlapping segments
     def get_overlap_seg(self) -> pd.DataFrame:
-        df_overlap_bbox = self.get_overlap_bbox()
+        dfbb = _filter_overlap_bbox(self._df)
+        def _iou(roi_i):
+            roi1 = self._df.loc[roi_i.index]
+            roi2 = self._df.loc[roi_i.overlaps_with]
+            print(roi1)
+            print(roi2)
+
+        dfbb['iou'] = dfbb.apply()
 
 
     # TODO: test if overlaps exist
@@ -311,19 +318,6 @@ class RoiSet(object):
     def add_df_col(self, name, se: pd.Series) -> None:
         self._df[name] = se
 
-    def get_multichannel_projection(self):
-        if self.count:
-            projected = project_stack_from_focal_points(
-                self._df['centroid-0'].to_numpy(),
-                self._df['centroid-1'].to_numpy(),
-                self._df['zi'].to_numpy(),
-                self.acc_raw,
-                degree=4,
-            )
-        else:  # else just return MIP
-            projected = self.acc_raw.data.max(axis=-1)
-        return projected
-
     def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack:  # padded, un-annotated 2d patches
         if channels and len(channels) == 1:
             patches_df = self.get_patches(white_channel=channels[0], **kwargs)
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 6f39f239..b06499c9 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -19,20 +19,62 @@ params = conf.meta['roiset']
 
 class TestOverlapLogic(unittest.TestCase):
 
-    def test_overlap_bbox(self):
-        df = pd.DataFrame({
-            'x0': [0, 1, 3, 1, 1],
+    def setUp(self) -> None:
+        self.df = pd.DataFrame({
+            'x0': [0, 1, 2, 1, 1],
             'x1': [2, 3, 4, 3, 3],
-            'y0': [0, 0, 0, 1, 0],
-            'y1': [1, 1, 1, 2, 1],
+            'y0': [0, 0, 0, 2, 0],
+            'y1': [2, 2, 2, 3, 2],
             'zi': [0, 0, 0, 0, 1],
         })
 
-        res = _filter_overlap_bbox(df)
-        print(res)
+        self.mask = np.array([
+            [1, 1, 0, 0, 0],
+            [1, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+        ])
+
+        self.mask = np.array([
+            [0, 0, 1, 0, 0],
+            [0, 1, 1, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+        ])
+
+        self.mask = np.array([
+            [0, 0, 1, 1, 0],
+            [0, 0, 1, 1, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+            [0, 0, 0, 0, 0],
+        ])
+
+        # first and second overlap bounding boxes but not segmentation
+        # second and third overlap both ways
+        self.masks = [
+            [
+                [1, 1],
+                [1, 0]
+            ],
+            [
+                [0, 1],
+                [1, 1]
+            ],
+            [
+                [1, 1],
+                [1, 1]
+            ],
+        ]
 
-        self.assertEqual(len(res), 1)
+    def test_overlap_bbox(self):
+        res = _filter_overlap_bbox(self.df)
+        print(res)
+        self.assertEqual(len(res), 2)
         self.assertTrue((res.loc[0, 'overlaps_with'] == 1).all())
+        self.assertTrue((res.loc[1, 'overlaps_with'] == 2).all())
 
 
 class BaseTestRoiSetMonoProducts(object):
@@ -181,26 +223,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         result = generate_file_accessor(where / file)
         self.assertEqual(result.shape, roiset.acc_raw.shape)
 
-    def test_flatten_image(self):
-        roiset = RoiSet.from_binary_mask(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
-        df = roiset.get_df()
-
-        from model_server.base.roiset import project_stack_from_focal_points
-
-        img = project_stack_from_focal_points(
-            df['centroid-0'].to_numpy(),
-            df['centroid-1'].to_numpy(),
-            df['zi'].to_numpy(),
-            self.stack,
-            degree=4,
-        )
-
-        self.assertEqual(img.shape[0:2], self.stack.shape[0:2])
-
-        write_accessor_data_to_file(
-            output_path / 'flattened.tif',
-            InMemoryDataAccessor(img)
-        )
 
     def test_make_binary_masks(self):
         roiset = self._make_roi_set()
-- 
GitLab