From d21eb159b924582b88cf8feb7644fcd8508e4163 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Wed, 20 Nov 2024 17:59:11 +0100 Subject: [PATCH] Overlay map exports as multichannel z-stack --- model_server/base/roiset.py | 27 +++++++++++++++++++++------ tests/base/test_roiset.py | 6 ++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 65ce27b2..5e4a1eb2 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -68,7 +68,7 @@ class RoiSetExportParams(BaseModel): patches_2d: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False - labels_overlay: [RoiSetLabelsOverlayParams, None] = None + labels_overlay: Union[RoiSetLabelsOverlayParams, None] = None derived_channels: bool = False @@ -689,11 +689,26 @@ class RoiSet(object): white_channel, transparency: float = 0.0 ) -> InMemoryDataAccessor: - # TODO: convert mono obj_ids to RGB map with glasbey palette - rgb_ids = self.acc_obj_ids.data_yxz - self.acc_raw.get_mono(channel=white_channel).data_yxz - palette = (np.array(glasbey.create_palette(as_hex=False)) * 255) - palette.insert(palette, 0, [0, 0, 0], axis=0) + # TODO: just make resample_to_8bit accessor method + mono = self.acc_raw.get_mono(channel=white_channel).apply(resample_to_8bit, preserve_dtype=False).data_yxz + max_label = self.get_df()['label'].max() + palette = np.array([[0, 0, 0]] + glasbey.create_palette(max_label, as_hex=False)) + rgb_8bit_palette = (255 * palette).round().astype('uint8') + id_map_yxzc = rgb_8bit_palette[self.acc_obj_ids.data_yxz] + id_map_yxcz = np.moveaxis( + id_map_yxzc, + [0, 1, 3, 2], + [0, 1, 2, 3] + ) + + # TODO: lift assertions + assert id_map_yxcz.shape[2] == 3 + assert id_map_yxcz.shape[0:2] == mono.shape[0:2] + assert id_map_yxcz.shape[3] == mono.shape[2] + assert id_map_yxcz.dtype == mono.dtype + + combined = np.stack([mono, mono, mono], axis=2) + id_map_yxcz + return InMemoryDataAccessor(combined) def get_serializable_dataframe(self) -> pd.DataFrame: return self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index ee73aa4b..cc6843fd 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -369,6 +369,12 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.chroma, 5) self.assertEqual(result.nz, 1) + def test_multichannel_to_label_ids_overlay(self): + where = output_path / 'multichannel' / 'multichannel_label_ids_overlay' + acc_out = self.roiset.get_object_identities_overlay_map(white_channel=3) + self.assertEqual(acc_out.chroma, 3) + acc_out.write(where / 'overlay.tif') + def test_multichannel_annotated_zstack(self): where = output_path / 'multichannel' / 'annotated_zstack' file = self.roiset.export_annotated_zstack( -- GitLab