diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index ba35a9a97f7f6b83f9fb718a236b9e7269a9139d..caf27aaec4e0c7ecd3ced0eb6fc29c45481c5bc7 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -796,8 +796,11 @@ class RoiSet(object):
         patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy()
 
         def _export_patch_mask(roi):
-            patch = InMemoryDataAccessor(roi.patch_mask)
-            ext = 'png'
+            patch = InMemoryDataAccessor.from_mono(roi.patch_mask)
+            if patch.nz == 1:
+                ext = 'png'
+            else:
+                ext = 'tif'
             fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
             write_accessor_data_to_file(where / fname, patch)
             return fname
@@ -835,16 +838,18 @@ class RoiSet(object):
             if pad_to:
                 patch = pad(patch, pad_to)
             if is_3d:
-                return np.expand_dims(patch, 3)
+                return patch
             else:
-                return np.expand_dims(patch, (2, 3))
+                return np.expand_dims(patch, 2)
 
         dfe = self._df.copy()
         dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1)
         return dfe
 
     def get_patch_masks_acc(self, **kwargs) -> PatchStack:
-        return PatchStack(list(self.get_patch_masks(**kwargs).patch_mask))
+        se_pm = self.get_patch_masks(**kwargs).patch_mask
+        se_ext = se_pm.apply(lambda x: np.expand_dims(x, 2))
+        return PatchStack(list(se_ext))
 
     def get_patches(
             self,
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 46896538a47c5f9067465e0a75fb3c9890535ca7..5dc1fd986c193ac4cff99f1e9c77704eda853236 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -242,8 +242,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         roiset = self._make_roi_set()
         df_patch_masks = roiset.get_patch_masks()
         for roi in df_patch_masks.itertuples():
-            h, w, nc, nz = roi.patch_mask.shape
-            self.assertEqual(nc, 1)
+            h, w, nz = roi.patch_mask.shape
             self.assertEqual(nz, 1)
             self.assertEqual(h, roi.h)
             self.assertEqual(w, roi.w)