From 05af7b71bd2d60ce9f4c605923c2fb50450da31c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 12 Oct 2023 16:42:12 +0200
Subject: [PATCH] Export z-stack of masks, encoded with object class from
 EcoTaxa table

---
 conf/testing.py               | 20 ++++++++++++++++++++
 extensions/chaeo/workflows.py | 16 ++++++++++------
 model_server/accessors.py     | 21 ++++++++++++++++++---
 tests/test_accessors.py       | 13 +++++++++++--
 4 files changed, 59 insertions(+), 11 deletions(-)

diff --git a/conf/testing.py b/conf/testing.py
index a089ccc5..5f95b5ca 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -12,6 +12,26 @@ czifile = {
     'z': 1,
 }
 
+filename = 'rgb.png'  # three-channel PNG fixture; `filename` is rebound for each fixture that follows
+rgbpngfile = {  # expected properties of rgb.png; default fixture of test_read_png in tests/test_accessors.py
+    'filename': filename,
+    'path': root / filename,  # `root` is not visible in this hunk; presumably defined earlier in conf/testing.py
+    'w': 64,  # width; compared with the second element of the accessor's `hw`
+    'h': 128,  # height; compared with the first element of the accessor's `hw`
+    'c': 3,  # channel count; compared with the accessor's `chroma`
+    'z': 1  # z-depth; NOTE(review): the PNG test asserts `nz` against a literal 1, not this key
+}
+
+filename = 'mono.png'  # single-channel PNG fixture
+monopngfile = {  # expected properties of mono.png; used by test_read_mono_png
+    'filename': filename,
+    'path': root / filename,  # `root` is not visible in this hunk; presumably defined earlier in conf/testing.py
+    'w': 64,  # width; compared with the second element of the accessor's `hw`
+    'h': 128,  # height; compared with the first element of the accessor's `hw`
+    'c': 1,  # channel count; compared with the accessor's `chroma`
+    'z': 1  # z-depth; NOTE(review): the PNG test asserts `nz` against a literal 1, not this key
+}
+
 filename = 'zmask-test-stack.tif'
 tifffile = {
     'filename': filename,
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 47da00bc..d6fc26e1 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -201,7 +201,6 @@ def transfer_ecotaxa_labels_to_patch_stacks(
         'hierarchy': se_unique,
         'annotation_class': df_split.loc[:, 1].str.lower()
     })
-    df_labels.to_csv(Path(where_output) / 'labels_key.csv')
 
     df_pf = pd.merge(
         df_merge[['patch_filename', 'object_annotation_hierarchy']],
@@ -211,17 +210,22 @@ def transfer_ecotaxa_labels_to_patch_stacks(
     )
     df_pl = df_pf[df_pf['object_annotation_hierarchy'].notnull()]
 
-    zstack = np.zeros((*patch_size, 1, len(df_pl)), dtype='uint16')
+    zstack = np.zeros((*patch_size, 1, len(df_pl)), dtype='uint8')
+
+    df_labels['counts'] = df_pl['annotation_class_id'].value_counts()
+    df_labels.to_csv(Path(where_output) / 'labels_key.csv')
 
     # export patches as z-stack
     for fi, pl in enumerate(df_pl.itertuples(name='PatchFile')):
         fn = pl._asdict()['patch_filename']
         ac = pl._asdict()['annotation_class_id']
-        bm = generate_file_accessor(Path(where_masks) / fn).data
-        assert bm.shape == patch_size, f'Unexpected patch size {patch_size}'
-        zstack[:, :, 0, fi] = (bm == 255) * ac
+        acc_bm = generate_file_accessor(Path(where_masks) / fn)
+        assert acc_bm.hw == patch_size, f'Unexpected patch size {patch_size}'
+        assert acc_bm.chroma == 1
+        assert acc_bm.nz == 1
+        zstack[:, :, 0, fi] = (acc_bm.data[:, :, 0, 0] == 255) * ac
 
     # export masks as z-stack
-    write_accessor_data_to_file(where_output / 'zstack_object_label.tif', InMemoryDataAccessor(zstack))
+    write_accessor_data_to_file(Path(where_output) / 'zstack_object_label.tif', InMemoryDataAccessor(zstack))
 
 
diff --git a/model_server/accessors.py b/model_server/accessors.py
index ae04dce3..bd6cc11e 100644
--- a/model_server/accessors.py
+++ b/model_server/accessors.py
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Dict
 
 import numpy as np
+from skimage.io import imread
 
 import czifile
 import tifffile
@@ -114,6 +115,20 @@ class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
     def __del__(self):
         self.tf.close()
 
+class PngFileAccessor(GenericImageFileAccessor):
+    """Read a PNG with skimage's imread and store it as 4-D data; raises FileAccessorError if reading fails."""
+    def __init__(self, fpath: Path):
+        super().__init__(fpath)
+        try:
+            arr = imread(fpath)
+        except Exception as e:  # must raise here: otherwise `arr` is unbound and an UnboundLocalError surfaces below
+            raise FileAccessorError(f'Unable to access data in {fpath}') from e
+
+        if len(arr.shape) == 3:  # multi-channel (e.g. RGB), presumably (h, w, c): append a singleton z axis
+            self._data = np.expand_dims(arr, 3)
+        else:  # mono, presumably (h, w): append singleton channel and z axes
+            self._data = np.expand_dims(arr, (2, 3))
+
 class CziImageFileAccessor(GenericImageFileAccessor):
     """
     Image that is stored in a Zeiss .CZI file; may be multi-channel, and/or a z-stack,
@@ -145,9 +160,7 @@ class CziImageFileAccessor(GenericImageFileAccessor):
         self.czifile.close()
 
 
-def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor, mkdir=True) -> bool:
-    if mkdir and not fpath.parent.exists():
-        fpath.parent.mkdir(parents=True)
+def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor) -> bool:
     try:
         zcyx= np.moveaxis(
             accessor.data, # yxcz
@@ -168,6 +181,8 @@ def generate_file_accessor(fpath):
         return TifSingleSeriesFileAccessor(fpath)
     elif str(fpath).upper().endswith('.CZI'):
         return CziImageFileAccessor(fpath)
+    elif str(fpath).upper().endswith('.PNG'):
+        return PngFileAccessor(fpath)
     else:
         raise FileAccessorError(f'Could not match a file accessor with {fpath}')
 
diff --git a/tests/test_accessors.py b/tests/test_accessors.py
index da39610a..4e1e7817 100644
--- a/tests/test_accessors.py
+++ b/tests/test_accessors.py
@@ -2,8 +2,8 @@ import unittest
 
 import numpy as np
 
-from conf.testing import czifile, output_path, tifffile
-from model_server.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
+from conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile
+from model_server.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
 
 class TestCziImageFileAccess(unittest.TestCase):
 
@@ -98,3 +98,12 @@ class TestCziImageFileAccess(unittest.TestCase):
         se = fh.series[0]
         fh_shape_dict = {se.axes[i]: se.shape[i] for i in range(0, len(se.shape))}
         self.assertEqual(fh_shape_dict, acc.shape_dict, 'Axes are not preserved in TIF output')
+
+    def test_read_png(self, pngfile=rgbpngfile):  # default argument lets unittest run this directly on the RGB fixture
+        acc = PngFileAccessor(pngfile['path'])
+        self.assertEqual(acc.hw, (pngfile['h'], pngfile['w']))  # `hw` is ordered (height, width)
+        self.assertEqual(acc.chroma, pngfile['c'])  # channel count declared by the fixture
+        self.assertEqual(acc.nz, 1)  # PngFileAccessor appends a singleton z axis, so exactly one plane is expected
+
+    def test_read_mono_png(self):
+        return self.test_read_png(pngfile=monopngfile)  # same checks, run against the single-channel fixture
\ No newline at end of file
-- 
GitLab