From 87fb235032b9fdeff6944fdb587e528d794c37dc Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 16 Oct 2023 11:46:25 +0200
Subject: [PATCH] Accomodate mono TIFs; support 'uint8' binary masks

---
 conf/testing.py           | 10 ++++++++++
 model_server/accessors.py | 20 +++++++++++++++-----
 tests/test_accessors.py   |  8 ++++++--
 3 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/conf/testing.py b/conf/testing.py
index 5f95b5ca..f2df5040 100644
--- a/conf/testing.py
+++ b/conf/testing.py
@@ -42,6 +42,16 @@ tifffile = {
     'z': 7,
 }
 
+filename = 'mono_zstack_mask.tif'
+monozstackmask = {
+    'filename': filename,
+    'path': root / filename,
+    'w': 256,
+    'h': 256,
+    'c': 1,
+    'z': 85
+}
+
 ilastik = {
     'pixel_classifier': 'demo_px.ilp',
     'object_classifier': 'demo_obj.ilp',
diff --git a/model_server/accessors.py b/model_server/accessors.py
index bd6cc11e..8ef4059e 100644
--- a/model_server/accessors.py
+++ b/model_server/accessors.py
@@ -34,7 +34,11 @@ class GenericImageDataAccessor(ABC):
         return True if self.shape_dict['Z'] > 1 else False
 
     def is_mask(self):
-        return self._data.dtype == 'bool'
+        if self._data.dtype == 'bool':
+            return True
+        elif self._data.dtype == 'uint8':
+            return np.all(np.unique(self._data) == [0, 255])
+        return False
 
     def get_one_channel_data (self, channel: int):
         c = int(channel)
@@ -101,12 +105,18 @@ class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
             raise DataShapeError(f'Expect only one series in {fpath}')
 
         se = tf.series[0]
-        sd = {ch: se.shape[se.axes.index(ch)] for ch in se.axes}
 
-        idx = {k: sd[k] for k in ['Y', 'X', 'C', 'Z']}
+        order = ['Y', 'X', 'C', 'Z']
+        axs = [a for a in se.axes if a in order]
+        da = se.asarray()
+
+        if 'C' not in axs:
+            axs.append('C')
+            da = np.expand_dims(da, len(da.shape))
+
         yxcz = np.moveaxis(
-            se.asarray(),
-            [se.axes.index(ch) for ch in idx],
+            da,
+            [axs.index(k) for k in order],
             [0, 1, 2, 3]
         )
 
diff --git a/tests/test_accessors.py b/tests/test_accessors.py
index 4e1e7817..f7a4341d 100644
--- a/tests/test_accessors.py
+++ b/tests/test_accessors.py
@@ -2,7 +2,7 @@ import unittest
 
 import numpy as np
 
-from conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile
+from conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask
 from model_server.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
 
 class TestCziImageFileAccess(unittest.TestCase):
@@ -106,4 +106,8 @@ class TestCziImageFileAccess(unittest.TestCase):
         self.assertEqual(acc.nz, 1)
 
     def test_read_mono_png(self):
-        return self.test_read_png(pngfile=monopngfile)
\ No newline at end of file
+        return self.test_read_png(pngfile=monopngfile)
+
+    def test_read_zstack_mono_mask(self):
+        acc = generate_file_accessor(monozstackmask['path'])
+        self.assertTrue(acc.is_mask())
\ No newline at end of file
-- 
GitLab