From 34a01200831d9622bfa101c2a60520a2c45fbbec Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 21 Dec 2023 15:03:23 +0100
Subject: [PATCH] Filters propagate from parameters dict

---
 extensions/chaeo/params.py            | 2 +-
 extensions/chaeo/tests/test_zstack.py | 6 +++++-
 extensions/chaeo/zmask.py             | 5 +++--
 3 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py
index 0ba956fb..a1f6aa8c 100644
--- a/extensions/chaeo/params.py
+++ b/extensions/chaeo/params.py
@@ -28,7 +28,7 @@ class RoiFilter(BaseModel):
 
 class RoiSetMetaParams(BaseModel):
     mask_type: str = 'boxes'
-    filters: RoiFilter = None
+    filters: RoiFilter = {}
     expand_box_by: List[int] = [128, 0]
 
 
diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index 7778370a..1779286b 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -211,7 +211,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
 
         roi_params = RoiSetMetaParams(**{
             'mask_type': 'boxes',
-            'filters': {},
+            'filters': {
+                'area': {'min': 1e3, 'max': 1e8}
+            },
             'expand_box_by': [128, 2]
         })
 
@@ -232,6 +234,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             'annotated_zstacks': {},
             'object_classes': True
         })
+
         infer_object_map_from_zstack(
             multichannel_zstack['path'],
             output_path / 'roiset' / 'workflow',
@@ -241,5 +244,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             segmentation_channel=pp['segmentation_channel'],
             patches_channel=pp['patches_channel'],
             export_params=export_params,
+            roi_params=roi_params,
         )
 
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 1441e9ae..053f1d44 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -174,9 +174,10 @@ def build_zmask_from_object_mask(
     lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
     query_str = 'label > 0'  # always true
     if filters is not None:
-        for k in filters.keys():
+        for k, val in filters.dict(exclude_unset=True).items():
             assert k in ('area', 'solidity')
-            vmin, vmax = filters[k]
+            vmin = val['min']
+            vmax = val['max']
             assert vmin >= 0
             query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
 
-- 
GitLab