diff --git a/conf/testing.py b/conf/testing.py index 5f95b5caa314ac6ff315677136094135c983ff5b..f2df5040d83ede67bd06bb1c114e2e02bc0a0772 100644 --- a/conf/testing.py +++ b/conf/testing.py @@ -42,6 +42,16 @@ tifffile = { 'z': 7, } +filename = 'mono_zstack_mask.tif' +monozstackmask = { + 'filename': filename, + 'path': root / filename, + 'w': 256, + 'h': 256, + 'c': 1, + 'z': 85 +} + ilastik = { 'pixel_classifier': 'demo_px.ilp', 'object_classifier': 'demo_obj.ilp', diff --git a/model_server/accessors.py b/model_server/accessors.py index bd6cc11e6557ae95f5104cd0567881036c8ef1d8..8ef4059eb4d3d6a7a8373035e39bb32e5ffbb6b9 100644 --- a/model_server/accessors.py +++ b/model_server/accessors.py @@ -34,7 +34,11 @@ class GenericImageDataAccessor(ABC): return True if self.shape_dict['Z'] > 1 else False def is_mask(self): - return self._data.dtype == 'bool' + if self._data.dtype == 'bool': + return True + elif self._data.dtype == 'uint8': + return np.all(np.unique(self._data) == [0, 255]) + return False def get_one_channel_data (self, channel: int): c = int(channel) @@ -101,12 +105,18 @@ class TifSingleSeriesFileAccessor(GenericImageFileAccessor): raise DataShapeError(f'Expect only one series in {fpath}') se = tf.series[0] - sd = {ch: se.shape[se.axes.index(ch)] for ch in se.axes} - idx = {k: sd[k] for k in ['Y', 'X', 'C', 'Z']} + order = ['Y', 'X', 'C', 'Z'] + axs = [a for a in se.axes if a in order] + da = se.asarray() + + if 'C' not in axs: + axs.append('C') + da = np.expand_dims(da, len(da.shape)) + yxcz = np.moveaxis( - se.asarray(), - [se.axes.index(ch) for ch in idx], + da, + [axs.index(k) for k in order], [0, 1, 2, 3] ) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 4e1e7817d724dfd78216541c9960278cc632ecd6..f7a4341dc04d42aff05ed6f01f0113578590684a 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -2,7 +2,7 @@ import unittest import numpy as np -from conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile +from conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask from model_server.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor class TestCziImageFileAccess(unittest.TestCase): @@ -106,4 +106,8 @@ class TestCziImageFileAccess(unittest.TestCase): self.assertEqual(acc.nz, 1) def test_read_mono_png(self): - return self.test_read_png(pngfile=monopngfile) \ No newline at end of file + return self.test_read_png(pngfile=monopngfile) + + def test_read_zstack_mono_mask(self): + acc = generate_file_accessor(monozstackmask['path']) + self.assertTrue(acc.is_mask()) \ No newline at end of file