From e8461666465249eb505f21a0375d4eba537ddb65 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 21 Nov 2024 07:06:48 +0100
Subject: [PATCH] .to_8bit was causing problems when applied to an accessor
 that was already 8-bt

---
 model_server/base/accessors.py | 5 ++++-
 model_server/base/roiset.py    | 8 ++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 49666cfe..7e7affa1 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -205,7 +205,10 @@ class GenericImageDataAccessor(ABC):
         )
 
     def to_8bit(self):
-        return self.apply(resample_to_8bit, preserve_dtype=False)
+        if self.dtype == 'uint8':
+            return self
+        else:
+            return self.apply(resample_to_8bit, preserve_dtype=False)
 
 class InMemoryDataAccessor(GenericImageDataAccessor):
     def __init__(self, data):
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 13ca95b7..e11450aa 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -63,6 +63,7 @@ class RoiSetExportParams(BaseModel):
     class RoiSetLabelsOverlayParams(BaseModel):
         transparency: float = 0.5
         mip: bool = False
+        rescale_clip: Union[float, None] = None
     patches_3d: Union[PatchParams, None] = None
     annotated_patches_2d: Union[PatchParams, None] = None
     patches_2d: Union[PatchParams, None] = None
@@ -688,8 +689,12 @@ class RoiSet(object):
             white_channel,
             transparency: float = 0.5,
             mip: bool = False,
+            rescale_clip: Union[float, None] = None,
     ) -> InMemoryDataAccessor:
-        mono = self.acc_raw.get_mono(channel=white_channel).to_8bit().data_yxz
+        mono = self.acc_raw.get_mono(channel=white_channel)
+        if rescale_clip is not None:
+            mono = mono.apply(lambda x: rescale(x, clip=rescale_clip))
+        mono = mono.to_8bit().data_yxz
         max_label = self.get_df()['label'].max()
         palette = np.array([[0, 0, 0]] + glasbey.create_palette(max_label, as_hex=False))
         rgb_8bit_palette = (255 * palette).round().astype('uint8')
@@ -699,7 +704,6 @@ class RoiSet(object):
             [0, 1, 3, 2],
             [0, 1, 2, 3]
         )
-
         combined = np.stack([mono, mono, mono], axis=2) + (1.0 - transparency) * id_map_yxcz
         combined_8bit = np.clip(combined, 0, 255).round().astype('uint8')
 
-- 
GitLab