From 25a7d75b79f180e8a616472edd5b8f23bfe39d57 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 26 Nov 2024 06:58:27 +0100
Subject: [PATCH] Merged in changes from remote

---
 model_server/base/roiset.py | 17 +++++++++--------
 tests/base/test_roiset.py   | 15 +++++++++------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index bd94b33e..8190377a 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -19,7 +19,7 @@ from skimage.morphology import binary_dilation, disk
 
 from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from .models import InstanceMaskSegmentationModel
-from .process import get_safe_contours, pad, rescale, make_rgb
+from .process import get_safe_contours, pad, rescale, make_rgb, resample_to_8bit
 from .annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
 from .accessors import generate_file_accessor, PatchStack
 from .process import mask_largest_object
@@ -42,6 +42,11 @@ class AnnotatedPatchParams(PatchParams):
     bounding_box_channel: int = 1
     bounding_box_linewidth: int = 2
 
+class RoiSetLabelsOverlayParams(BaseModel):
+    white_channel: int
+    transparency: float = 0.5
+    mip: bool = False
+    rescale_clip: Union[float, None] = None
 
 class AnnotatedZStackParams(BaseModel):
     draw_label: bool = False
@@ -67,10 +72,6 @@ class RoiSetMetaParams(BaseModel):
 
 
 class RoiSetExportParams(BaseModel):
-    class RoiSetLabelsOverlayParams(BaseModel):
-        transparency: float = 0.5
-        mip: bool = False
-        rescale_clip: Union[float, None] = None
     patches_3d: Union[PatchParams, None] = None
     annotated_patches_2d: Union[AnnotatedPatchParams, None] = None
     patches_2d: Union[PatchParams, None] = None
@@ -769,7 +770,7 @@ class RoiSet(object):
             fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
 
             if patch.dtype == 'uint16':
-                resampled = patch.apply(resample_to_8bit)
+                resampled = patch.to_8bit()
                 write_accessor_data_to_file(where / fname, resampled)
             else:
                 write_accessor_data_to_file(where / fname, patch)
@@ -997,7 +998,7 @@ class RoiSet(object):
                     dacc.export_pyxcz(fp)
                     record[k].append(str(fp))
             if k == 'labels_overlay':
-                fn = self.export_object_identities_overlay_map(subdir, prefix=prefix, white_channel=channel, **kp)
+                fn = self.export_object_identities_overlay_map(subdir, prefix=prefix, **kp)
                 record[k] = str(Path(k) / fn)
 
         # export dataframe and patch masks
@@ -1034,7 +1035,7 @@ class RoiSet(object):
                 for n in cnames:
                     interm[f'{k}_{n}'] = self.get_object_class_map(n)
             if k == 'labels_overlay':
-                interm[k] = self.get_object_identities_overlay_map(channel, **kp)
+                interm[k] = self.get_object_identities_overlay_map(**kp)
 
         return interm
 
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index c9d01b2c..9fc1ff74 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -373,8 +373,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
         # export z-stack product
         acc_out = self.roiset.get_export_product_accessors(
-            channel=-1,
-            params=RoiSetExportParams(**{'labels_overlay': {'transparency': 0.5}})
+            params=RoiSetExportParams(**{'labels_overlay': {'white_channel': -1, 'transparency': 0.5}})
         )['labels_overlay']
         self.assertEqual(acc_out.chroma, 3)
         self.assertGreater(acc_out.nz, 1)
@@ -382,8 +381,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
         # export MIP product
         acc_out = self.roiset.get_export_product_accessors(
-            channel=0,
-            params=RoiSetExportParams(**{'labels_overlay': {'transparency': 0.2, 'mip': True}})
+            params=RoiSetExportParams(**{'labels_overlay': {'white_channel': 0, 'transparency': 0.2, 'mip': True}})
         )['labels_overlay']
         self.assertEqual(acc_out.chroma, 3)
         self.assertEqual(acc_out.nz, 1)
@@ -392,9 +390,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         # export via .run_exports() method
         self.roiset.run_exports(
             where,
-            channel=-1,
             prefix='run_exports',
-            params=RoiSetExportParams(**{'labels_overlay': {'transparency': 0.2, 'mip': True, 'rescale_clip': 0.00}})
+            params=RoiSetExportParams(
+                **{
+                    'labels_overlay': {
+                        'white_channel': -1, 'transparency': 0.2, 'mip': True, 'rescale_clip': 0.00
+                    }
+                }
+            )
         )
 
 
-- 
GitLab