diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 09fc6c424143ff77c39566df7229b59efededc8a..608e9d0965f5ea68f6186b36eb9c591dcb91f6fa 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -17,7 +17,7 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc
 from model_server.base.models import InstanceSegmentationModel
 from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb
 from base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
-from model_server.extensions.chaeo.accessors import write_patch_to_file, MonoPatchStack, PatchStack
+from model_server.extensions.chaeo.accessors import write_patch_to_file, PatchStack
 from base.process import mask_largest_object
 
 
@@ -225,16 +225,15 @@ class RoiSet(object):
             projected = self.acc_raw.data.max(axis=-1)
         return projected
 
+    # TODO: remove, since padding is implicit in PatchStack
+    # TODO: test case where patch channel is restricted
     def get_raw_patches(self, channel=None, pad_to=256, make_3d=False):  # padded, un-annotated 2d patches
         if channel:
             patches_df = self.get_patches(white_channel=channel, pad_to=pad_to)
         else:
             patches_df = self.get_patches(pad_to=pad_to)
         patches = list(patches_df['patch'])
-        if channel is not None or self.acc_raw.chroma == 1:
-            return MonoPatchStack(patches)
-        else:
-            return PatchStack(patches)
+        return PatchStack(patches)
 
     def export_annotated_zstack(self, where, prefix='zstack', **kwargs):
         annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs))
@@ -340,7 +339,7 @@ class RoiSet(object):
 
         return exported
 
-    def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack:
+    def get_patch_masks(self, pad_to: int = 256) -> PatchStack:
         patches = []
         for roi in self:
             patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
@@ -350,7 +349,8 @@ class RoiSet(object):
                 patch = pad(patch, pad_to)
 
             patches.append(patch)
-        return MonoPatchStack(patches)
+        return PatchStack(patches)
+
 
     def get_patches(
             self,
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index df77c111ec69b372d6778465cc119d385aae3316..5aa3c28ef61c81022630a8db096868be58a43a91 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -158,6 +158,20 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
         self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
 
+    def test_raw_patches_are_correct_shape(self):
+        roiset = self._make_roi_set()
+        patches = roiset.get_raw_patches()
+        np, h, w, nc, nz = patches.shape
+        self.assertEqual(np, roiset.count)
+        self.assertEqual(nc, roiset.acc_raw.chroma)
+
+    def test_patch_masks_are_correct_shape(self):
+        roiset = self._make_roi_set()
+        patch_masks = roiset.get_patch_masks()
+        np, h, w, nc, nz = patch_masks.shape
+        self.assertEqual(np, roiset.count)
+        self.assertEqual(nc, 1)
+
 
 class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
@@ -233,3 +247,5 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.chroma, self.stack.chroma)
         self.assertEqual(result.nz, self.stack.nz)
 
+
+