From d4ce55f1a38c555f1d8b904bfe2c1b48eca16c58 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 6 Feb 2025 09:57:41 +0100
Subject: [PATCH] Reorganizing things in RoiSet

---
 model_server/base/rois/__init__.py |   0
 model_server/base/rois/df.py       | 124 ++++++++++++++++++++++++++++
 model_server/base/roiset.py        | 126 +----------------------------
 tests/base/test_roiset.py          |   3 +-
 4 files changed, 129 insertions(+), 124 deletions(-)
 create mode 100644 model_server/base/rois/__init__.py
 create mode 100644 model_server/base/rois/df.py

diff --git a/model_server/base/rois/__init__.py b/model_server/base/rois/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_server/base/rois/df.py b/model_server/base/rois/df.py
new file mode 100644
index 00000000..69ca70f8
--- /dev/null
+++ b/model_server/base/rois/df.py
@@ -0,0 +1,124 @@
+import itertools
+
+import numpy as np
+import pandas as pd
+
+def filter_df(df: pd.DataFrame, filters: dict = {}) -> pd.DataFrame:  # NOTE(review): mutable default; not mutated here, but None would be safer
+    query_str = 'label > 0'  # base clause, presumably always true (assumes positive labels -- TODO confirm); filters are AND-ed onto it
+    if filters is not None:  # parse filters
+        for k, val in filters.items():
+            assert k in ('area', 'diag', 'min_hw')  # only these names are accepted as filter keys
+            if val is None:  # a key mapped to None adds no clause
+                continue
+            vmin = val['min']  # each filter value is read as a mapping with 'min' and 'max' entries
+            vmax = val['max']
+            assert vmin >= 0
+            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'  # both bounds are exclusive
+    return df.loc[df.bounding_box.query(query_str).index, :]  # query is evaluated on df.bounding_box; rows of df with a matching index are returned
+
+
+def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
+    """
+    If passed a single DataFrame, return the subset whose bounding boxes overlap in x and y at the same zi.  If passed
+    two DataFrames, return the subset where a ROI in the first overlaps a ROI in the second.  May return duplicate
+    entries where a ROI overlaps with multiple neighbors.
+    :param df1: DataFrame with potentially overlapping bounding boxes
+    :param df2: (optional) second DataFrame
+    :return: bounding_box rows of df1 for the overlapping ROIs, with two added columns
+        overlaps_with: index of ROI that overlaps
+        bbox_intersec: value computed by _intersec, intended as pixel area of the intersecting region
+    """
+
+    def _compare(r0, r1):
+        olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)  # strict inequalities: boxes that only share an edge do not count
+        oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0)
+        olz = (r0.zi == r1.zi)  # z is matched by equal zi only
+        return olx and oly and olz
+
+    def _intersec(r0, r1):
+        return (r0.x1 - r1.x0) * (r0.y1 - r1.y0)  # NOTE(review): not clipped with min/max, so this is the true intersection area only for some box arrangements -- confirm intent
+
+    first = []  # index labels of the ROI in each overlapping pair
+    second = []  # index labels of the partner it overlaps with
+    intersec = []  # _intersec value for each recorded pair
+
+    if df2 is not None:
+        for pair in itertools.product(df1.index, df2.index):  # two frames: test every (df1 label, df2 label) pair
+            if _compare(
+                    df1.bounding_box.loc[pair[0]],
+                    df2.bounding_box.loc[pair[1]]
+            ):
+                first.append(pair[0])
+                second.append(pair[1])
+                intersec.append(
+                    _intersec(
+                        df1.bounding_box.loc[pair[0]],
+                        df2.bounding_box.loc[pair[1]]
+                    )
+                )
+    else:
+        for pair in itertools.combinations(df1.index, 2):  # single frame: test each unordered pair of labels once
+            if _compare(
+                    df1.bounding_box.loc[pair[0]],
+                    df1.bounding_box.loc[pair[1]]
+            ):
+                first.append(pair[0])
+                second.append(pair[1])
+                first.append(pair[1])  # record the reversed pair too, so both ROIs appear in the result
+                second.append(pair[0])
+                isc = _intersec(
+                    df1.bounding_box.loc[pair[0]],
+                    df1.bounding_box.loc[pair[1]]
+                )
+                intersec.append(isc)
+                intersec.append(isc)  # the same value is reused for the reversed pair
+
+    sdf = df1.bounding_box.loc[first]  # one row per recorded overlap, so a label repeats once per partner
+    sdf.loc[:, 'overlaps_with'] = second
+    sdf.loc[:, 'bbox_intersec'] = intersec
+    return sdf
+
+
+def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
+    """
+    For each pair of ROIs whose bounding boxes overlap (as found by filter_df_overlap_bbox), compare their binary masks.
+    Pairs come from within df1 if passed a single DataFrame, or from df1 against df2 if passed two.  Rows are not
+    dropped when the masks are disjoint (see seg_overlaps), and a ROI appears once per overlapping neighbor.
+    :param df1: DataFrame with bounding_box columns and a masks.binary_mask entry per ROI
+    :param df2: (optional) second DataFrame
+    :return: the DataFrame returned by filter_df_overlap_bbox, with three added columns
+        seg_overlaps: whether any pixel of the combined mask exceeds 1 (overlaps_with holds the partner's index)
+        seg_intersec: number of pixels whose combined mask value is 2
+        seg_iou: seg_intersec divided by the number of non-zero pixels in the combined mask
+    """
+
+    dfbb = filter_df_overlap_bbox(df1, df2)  # candidate pairs come from the bounding-box test
+
+    def _overlap_seg(r):
+        roi1 = df1.loc[r.name]  # r.name is the row's index label in dfbb, used to look the ROI up in df1
+        if df2 is not None:
+            roi2 = df2.loc[r.overlaps_with]
+        else:
+            roi2 = df1.loc[r.overlaps_with]
+        bb1 = roi1.bounding_box
+        bb2 = roi2.bounding_box
+        ex0 = min(bb1.x0, bb2.x0, bb1.x1, bb2.x1)  # origin and size of the smallest canvas that encloses both boxes
+        ew = max(bb1.x0, bb2.x0, bb1.x1, bb2.x1) - ex0
+        ey0 = min(bb1.y0, bb2.y0, bb1.y1, bb2.y1)
+        eh = max(bb1.y0, bb2.y0, bb1.y1, bb2.y1) - ey0
+        emask = np.zeros((eh, ew), dtype='uint8')
+        sl1 = np.s_[(bb1.y0 - ey0): (bb1.y1 - ey0), (bb1.x0 - ex0): (bb1.x1 - ex0)]  # each box's position within the canvas
+        sl2 = np.s_[(bb2.y0 - ey0): (bb2.y1 - ey0), (bb2.x0 - ex0): (bb2.x1 - ex0)]
+        emask[sl1] = roi1.masks.binary_mask  # assumes binary_mask is boolean or 0/1 and matches the box shape -- TODO confirm
+        emask[sl2] = emask[sl2] + roi2.masks.binary_mask  # added, not assigned, so pixels in both masks read 2
+        return emask
+
+    emasks = dfbb.apply(_overlap_seg, axis=1)  # one combined mask per candidate pair
+    dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1))
+    dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum())
+    dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum())
+    return dfbb
+
+
+def is_df_3d(df: pd.DataFrame) -> bool:
+    return 'z0' in df.bounding_box.columns and 'z1' in df.bounding_box.columns  # True only if both z0 and z1 are columns of df.bounding_box
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 244acc65..bf729a6a 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1,5 +1,4 @@
 from collections import OrderedDict
-import itertools
 from math import sqrt, floor
 from pathlib import Path
 from typing import Dict, List, Union
@@ -23,6 +22,7 @@ from .process import get_safe_contours, pad, rescale, make_rgb
 from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
 from .accessors import generate_file_accessor, PatchStack
 from .process import mask_largest_object
+from .rois.df import filter_df, filter_df_overlap_seg, is_df_3d
 
 
 class PatchParams(BaseModel):
@@ -140,127 +140,6 @@ def focus_metrics():
     }
 
 
-def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame:
-    query_str = 'label > 0'  # always true
-    if filters is not None:  # parse filters
-        for k, val in filters.dict(exclude_unset=True).items():
-            assert k in ('area', 'diag', 'min_hw')
-            if val is None:
-                continue
-            vmin = val['min']
-            vmax = val['max']
-            assert vmin >= 0
-            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
-    return df.loc[df.bounding_box.query(query_str).index, :]
-
-
-def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
-    """
-    If passed a single DataFrame, return the subset whose bounding boxes overlap in 3D space.  If passed two DataFrames,
-    return the subset where a ROI in the first overlaps a ROI in the second.  May return duplicates entries where a ROI
-    overlaps with multiple neighbors.
-    :param df1: DataFrame with potentially overlapping bounding boxes
-    :param df2: (optional) second DataFrame
-    :return DataFrame describing subset of overlapping ROIs
-        bbox_overlaps_with: index of ROI that overlaps
-        bbox_intersec: pixel area of intersecting region
-    """
-
-    def _compare(r0, r1):
-        olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0)
-        oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0)
-        olz = (r0.zi == r1.zi)
-        return olx and oly and olz
-
-    def _intersec(r0, r1):
-        return (r0.x1 - r1.x0) * (r0.y1 - r1.y0)
-
-    first = []
-    second = []
-    intersec = []
-
-    if df2 is not None:
-        for pair in itertools.product(df1.index, df2.index):
-            if _compare(
-                    df1.bounding_box.loc[pair[0]],
-                    df2.bounding_box.loc[pair[1]]
-            ):
-                first.append(pair[0])
-                second.append(pair[1])
-                intersec.append(
-                    _intersec(
-                        df1.bounding_box.loc[pair[0]],
-                        df2.bounding_box.loc[pair[1]]
-                    )
-                )
-    else:
-        for pair in itertools.combinations(df1.index, 2):
-            if _compare(
-                    df1.bounding_box.loc[pair[0]],
-                    df1.bounding_box.loc[pair[1]]
-            ):
-                first.append(pair[0])
-                second.append(pair[1])
-                first.append(pair[1])
-                second.append(pair[0])
-                isc = _intersec(
-                    df1.bounding_box.loc[pair[0]],
-                    df1.bounding_box.loc[pair[1]]
-                )
-                intersec.append(isc)
-                intersec.append(isc)
-
-    sdf = df1.bounding_box.loc[first]
-    sdf.loc[:, 'overlaps_with'] = second
-    sdf.loc[:, 'bbox_intersec'] = intersec
-    return sdf
-
-
-def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame:
-    """
-    If passed a single DataFrame, return the subset whose segmentations overlap in 3D space.  If passed two DataFrames,
-    return the subset where a ROI in the first overlaps a ROI in the second.  May return duplicates entries where a ROI
-    overlaps with multiple neighbors.
-    :param df1: DataFrame with potentially overlapping bounding boxes
-    :param df2: (optional) second DataFrame
-    :return DataFrame describing subset of overlapping ROIs
-        seg_overlaps_with: index of ROI that overlaps
-        seg_intersec: pixel area of intersecting region
-        seg_iou: intersection over union
-    """
-
-    dfbb = filter_df_overlap_bbox(df1, df2)
-
-    def _overlap_seg(r):
-        roi1 = df1.loc[r.name]
-        if df2 is not None:
-            roi2 = df2.loc[r.overlaps_with]
-        else:
-            roi2 = df1.loc[r.overlaps_with]
-        bb1 = roi1.bounding_box
-        bb2 = roi2.bounding_box
-        ex0 = min(bb1.x0, bb2.x0, bb1.x1, bb2.x1)
-        ew = max(bb1.x0, bb2.x0, bb1.x1, bb2.x1) - ex0
-        ey0 = min(bb1.y0, bb2.y0, bb1.y1, bb2.y1)
-        eh = max(bb1.y0, bb2.y0, bb1.y1, bb2.y1) - ey0
-        emask = np.zeros((eh, ew), dtype='uint8')
-        sl1 = np.s_[(bb1.y0 - ey0): (bb1.y1 - ey0), (bb1.x0 - ex0): (bb1.x1 - ex0)]
-        sl2 = np.s_[(bb2.y0 - ey0): (bb2.y1 - ey0), (bb2.x0 - ex0): (bb2.x1 - ex0)]
-        emask[sl1] = roi1.masks.binary_mask
-        emask[sl2] = emask[sl2] + roi2.masks.binary_mask
-        return emask
-
-    emasks = dfbb.apply(_overlap_seg, axis=1)
-    dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1))
-    dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum())
-    dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum())
-    return dfbb
-
-
-def is_df_3d(df: pd.DataFrame) -> bool:
-    return 'z0' in df.bounding_box.columns and 'z1' in df.bounding_box.columns
-
-
 def make_df_from_object_ids(
         acc_raw,
         acc_obj_ids,
@@ -360,7 +239,8 @@ def make_df_from_object_ids(
     df = df.set_index('label')
     insert_level(df, 'bounding_box')
     df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by)
-    df_fil = filter_df(df, filters)
+    filters_dict = {} if filters is None else filters.dict(exclude_unset=True)
+    df_fil = filter_df(df, filters_dict)
     df_fil['masks', 'binary_mask'] = df_fil.bounding_box.apply(
         _make_binary_mask,
         axis=1,
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index fc017505..ade3a5ef 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -7,7 +7,8 @@ from pathlib import Path
 import pandas as pd
 
 from model_server.base.process import smooth
-from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, IntensityThresholdInstanceMaskSegmentationModel, read_roiset_df, RoiSet, RoiSetExportParams, RoiSetMetaParams
+from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel, read_roiset_df, RoiSet, RoiSetExportParams, RoiSetMetaParams
+from model_server.base.rois.df import filter_df_overlap_bbox, filter_df_overlap_seg
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
 import model_server.conf.testing as conf
 from model_server.conf.testing import DummyInstanceMaskSegmentationModel
-- 
GitLab