diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 09fc6c424143ff77c39566df7229b59efededc8a..608e9d0965f5ea68f6186b36eb9c591dcb91f6fa 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -17,7 +17,7 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc from model_server.base.models import InstanceSegmentationModel from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb from base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image -from model_server.extensions.chaeo.accessors import write_patch_to_file, MonoPatchStack, PatchStack +from model_server.extensions.chaeo.accessors import write_patch_to_file, PatchStack from base.process import mask_largest_object @@ -225,16 +225,15 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected + # TODO: remove, since padding is implicit in PatchStack + # TODO: test case where patch channel is restricted def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches if channel: patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) else: patches_df = self.get_patches(pad_to=pad_to) patches = list(patches_df['patch']) - if channel is not None or self.acc_raw.chroma == 1: - return MonoPatchStack(patches) - else: - return PatchStack(patches) + return PatchStack(patches) def export_annotated_zstack(self, where, prefix='zstack', **kwargs): annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) @@ -340,7 +339,7 @@ class RoiSet(object): return exported - def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack: + def get_patch_masks(self, pad_to: int = 256) -> PatchStack: patches = [] for roi in self: patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') @@ -350,7 +349,8 @@ class RoiSet(object): patch = pad(patch, pad_to) patches.append(patch) - return MonoPatchStack(patches) + return PatchStack(patches) + def get_patches( self, diff --git a/tests/test_roiset.py b/tests/test_roiset.py index df77c111ec69b372d6778465cc119d385aae3316..5aa3c28ef61c81022630a8db096868be58a43a91 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -158,6 +158,20 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + def test_raw_patches_are_correct_shape(self): + roiset = self._make_roi_set() + patches = roiset.get_raw_patches() + np, h, w, nc, nz = patches.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, roiset.acc_raw.chroma) + + def test_patch_masks_are_correct_shape(self): + roiset = self._make_roi_set() + patch_masks = roiset.get_patch_masks() + np, h, w, nc, nz = patch_masks.shape + self.assertEqual(np, roiset.count) + self.assertEqual(nc, 1) + class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): @@ -233,3 +247,5 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.chroma, self.stack.chroma) self.assertEqual(result.nz, self.stack.nz) + +