From a857e683e5deb2701522a3f7ede1dbef17b5f5c0 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 12 Oct 2023 16:29:22 +0200
Subject: [PATCH] Support mono PNG too

---
 conf/testing.py           | 14 ++++++++++++--
 model_server/accessors.py |  5 ++++-
 tests/test_accessors.py   |  9 ++++++---
 3 files changed, 22 insertions(+), 6 deletions(-)

diff --git a/conf/testing.py b/conf/testing.py
index 59d642b8..5f95b5ca 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -12,8 +12,8 @@ czifile = {
     'z': 1,
 }
 
-filename = 'test_img.png'
-pngfile = {
+filename = 'rgb.png'
+rgbpngfile = {
     'filename': filename,
     'path': root / filename,
     'w': 64,
@@ -22,6 +22,16 @@ pngfile = {
     'z': 1
 }
 
+filename = 'mono.png'
+monopngfile = {
+    'filename': filename,
+    'path': root / filename,
+    'w': 64,
+    'h': 128,
+    'c': 1,
+    'z': 1
+}
+
 filename = 'zmask-test-stack.tif'
 tifffile = {
     'filename': filename,
diff --git a/model_server/accessors.py b/model_server/accessors.py
index 84d29708..bd6cc11e 100644
--- a/model_server/accessors.py
+++ b/model_server/accessors.py
@@ -124,7 +124,10 @@ class PngFileAccessor(GenericImageFileAccessor):
         except Exception:
             FileAccessorError(f'Unable to access data in {fpath}')
 
-        self._data = np.expand_dims(arr, 3)
+        if len(arr.shape) == 3: # rgb
+            self._data = np.expand_dims(arr, 3)
+        else: # mono
+            self._data = np.expand_dims(arr, (2, 3))
 
 class CziImageFileAccessor(GenericImageFileAccessor):
     """
diff --git a/tests/test_accessors.py b/tests/test_accessors.py
index 426afa07..4e1e7817 100644
--- a/tests/test_accessors.py
+++ b/tests/test_accessors.py
@@ -2,7 +2,7 @@ import unittest
 
 import numpy as np
 
-from conf.testing import czifile, output_path, pngfile, tifffile
+from conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile
 from model_server.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
 
 class TestCziImageFileAccess(unittest.TestCase):
@@ -99,8 +99,11 @@ class TestCziImageFileAccess(unittest.TestCase):
         fh_shape_dict = {se.axes[i]: se.shape[i] for i in range(0, len(se.shape))}
         self.assertEqual(fh_shape_dict, acc.shape_dict, 'Axes are not preserved in TIF output')
 
-    def test_read_rgb_png(self):
+    def test_read_png(self, pngfile=rgbpngfile):
         acc = PngFileAccessor(pngfile['path'])
         self.assertEqual(acc.hw, (pngfile['h'], pngfile['w']))
-        self.assertEqual(acc.chroma, 3)
+        self.assertEqual(acc.chroma, pngfile['c'])
         self.assertEqual(acc.nz, 1)
+
+    def test_read_mono_png(self):
+        return self.test_read_png(pngfile=monopngfile)
\ No newline at end of file
-- 
GitLab