From 1263a4e6cde070d73474005c6d98710b844f798b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 4 Apr 2024 15:17:40 +0200
Subject: [PATCH] Allow None as an argument to rgb_overlay_channels, indicating
 export of mono patches

---
 model_server/base/roiset.py |  2 +-
 tests/test_roiset.py        | 31 ++++++++++++++++++++++++++++++-
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index e50068d7..dd3139a1 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -28,7 +28,7 @@ class PatchParams(BaseModel):
     draw_mask: bool = False
     rescale_clip: float = 0.001
     focus_metric: str = 'max_sobel'
-    rgb_overlay_channels: List[Union[int, None]] = [None, None, None]
+    rgb_overlay_channels: Union[List[Union[int, None]], None] = None
     rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0]
     pad_to: int = 256
     expanded: bool = False
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index de58f25f..e281b7a0 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -428,9 +428,38 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             self.assertEqual(pacc.hw, (256, 256))
         print('res')
 
+    def test_run_export_mono_2d_patch(self):
+        p = RoiSetExportParams(**{
+            'patches_2d': {
+                'draw_bounding_box': False,
+                'draw_mask': False,
+                'expanded': True,
+                'pad_to': 256,
+                'rgb_overlay_channels': None,
+            },
+        })
+        self.assertTrue(hasattr(p.patches_2d, 'pad_to'))
+        self.assertTrue(hasattr(p.patches_2d, 'expanded'))
+
+        where = output_path / 'run_exports_mono_2d_patch'
+        res = self.roiset.run_exports(
+            where,
+            channel=-1,
+            prefix='test',
+            params=p
+        )
+
+        # test that exported patches are padded dimension
+        for fn in res['patches_2d']:
+            pa = where / fn
+            self.assertTrue(pa.exists())
+            pacc = generate_file_accessor(pa)
+            self.assertEqual(pacc.chroma, 1)
+        print('res')
+
 
 
-class TestRoiSetFromZmask(unittest.TestCase):
+class TestRoiSetSerialization(unittest.TestCase):
 
     def setUp(self) -> None:
         # set up test raw data and segmentation from file
-- 
GitLab