From 5bb914c7bf3446745f06a681fc592c30b7b01c66 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 21 Nov 2024 14:41:07 +0100
Subject: [PATCH] Optional intensity threshold for weight deprojection

---
 model_server/base/roiset.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index a4397d5b..8dc5ff05 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -57,6 +57,7 @@ class RoiSetMetaParams(BaseModel):
     filters: Union[RoiFilter, None] = None
     expand_box_by: List[int] = [128, 0]
     deproject_channel: Union[int, None] = None
+    deproject_intensity_threshold: float = 0.0
 
 
 class RoiSetExportParams(BaseModel):
@@ -222,13 +223,22 @@ def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.Dat
     return dfbb
 
 
-def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_channel=None, filters=None) -> pd.DataFrame:
+def make_df_from_object_ids(
+        acc_raw,
+        acc_obj_ids,
+        expand_box_by,
+        deproject_channel=None,
+        filters=None,
+        deproject_intensity_threshold=0.0
+) -> pd.DataFrame:
     """
     Build dataframe that associate object IDs with summary stats;
     :param acc_raw: accessor to raw image data
     :param acc_obj_ids: accessor to map of object IDs
     :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary)
     :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel
+    :param deproject_intensity_threshold: when deprojecting, round MIP deprojection_channel to zero if below this
+        threshold (as fraction of full range, 0.0 to 1.0)
     :return: pd.DataFrame
     """
     # build dataframe of objects, assign z index to each object
@@ -244,10 +254,11 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_chann
                 )
 
         mono = acc_raw.get_mono(deproject_channel)
-        mip = mono.get_mip()
+        intensity_weight = mono.get_mip().data_yx
+        intensity_weight[intensity_weight < (deproject_intensity_threshold * mono.dtype_max)] = 0
         zi_map = np.stack([
-            mip.data_yx,
-            mono.get_z_argmax().data_yx * mip.data_yx
+            intensity_weight,
+            mono.get_z_argmax().data_yx * intensity_weight,
         ], axis=-1)
 
         assert len(zi_map.shape) == 3
@@ -414,6 +425,7 @@ class RoiSet(object):
             acc_raw, acc_obj_ids,
             expand_box_by=params.expand_box_by,
             deproject_channel=params.deproject_channel,
+            deproject_intensity_threshold=params.deproject_intensity_threshold,
             filters=params.filters,
         )
 
-- 
GitLab