From d21eb159b924582b88cf8feb7644fcd8508e4163 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 20 Nov 2024 17:59:11 +0100
Subject: [PATCH] Overlay map exports as multichannel z-stack

---
 model_server/base/roiset.py | 27 +++++++++++++++++++++------
 tests/base/test_roiset.py   |  6 ++++++
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 65ce27b2..5e4a1eb2 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -68,7 +68,7 @@ class RoiSetExportParams(BaseModel):
     patches_2d: Union[PatchParams, None] = None
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
     object_classes: bool = False
-    labels_overlay: [RoiSetLabelsOverlayParams, None] = None
+    labels_overlay: Union[RoiSetLabelsOverlayParams, None] = None
     derived_channels: bool = False
 
 
@@ -689,11 +689,26 @@ class RoiSet(object):
             white_channel,
             transparency: float = 0.0
     ) -> InMemoryDataAccessor:
-        # TODO: convert mono obj_ids to RGB map with glasbey palette
-        rgb_ids = self.acc_obj_ids.data_yxz
-        self.acc_raw.get_mono(channel=white_channel).data_yxz
-        palette = (np.array(glasbey.create_palette(as_hex=False)) * 255)
-        palette.insert(palette, 0, [0, 0, 0], axis=0)
+        # TODO: just make resample_to_8bit accessor method
+        mono = self.acc_raw.get_mono(channel=white_channel).apply(resample_to_8bit, preserve_dtype=False).data_yxz
+        max_label = self.get_df()['label'].max()
+        palette = np.array([[0, 0, 0]] + glasbey.create_palette(max_label, as_hex=False))
+        rgb_8bit_palette = (255 * palette).round().astype('uint8')
+        id_map_yxzc = rgb_8bit_palette[self.acc_obj_ids.data_yxz]
+        id_map_yxcz = np.moveaxis(
+            id_map_yxzc,
+            [0, 1, 3, 2],
+            [0, 1, 2, 3]
+        )
+
+        # TODO: lift assertions
+        assert id_map_yxcz.shape[2] == 3
+        assert id_map_yxcz.shape[0:2] == mono.shape[0:2]
+        assert id_map_yxcz.shape[3] == mono.shape[2]
+        assert id_map_yxcz.dtype == mono.dtype
+
+        combined = np.stack([mono, mono, mono], axis=2) + id_map_yxcz
+        return InMemoryDataAccessor(combined)
 
     def get_serializable_dataframe(self) -> pd.DataFrame:
         return self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1)
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index ee73aa4b..cc6843fd 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -369,6 +369,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.chroma, 5)
         self.assertEqual(result.nz, 1)
 
+    def test_multichannel_to_label_ids_overlay(self):
+        where = output_path / 'multichannel' / 'multichannel_label_ids_overlay'
+        acc_out = self.roiset.get_object_identities_overlay_map(white_channel=3)
+        self.assertEqual(acc_out.chroma, 3)
+        acc_out.write(where / 'overlay.tif')
+
     def test_multichannel_annotated_zstack(self):
         where = output_path / 'multichannel' / 'annotated_zstack'
         file = self.roiset.export_annotated_zstack(
-- 
GitLab