Skip to content
Snippets Groups Projects
Commit d992c9d2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Consolidate usages of PatchStack in RoiSet class

parent eb0b6e03
No related branches found
No related tags found
No related merge requests found
...@@ -17,7 +17,7 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc ...@@ -17,7 +17,7 @@ from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAc
from model_server.base.models import InstanceSegmentationModel from model_server.base.models import InstanceSegmentationModel
from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb from model_server.base.process import pad, rescale, resample_to_8bit, make_rgb
from base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image from base.annotators import draw_box_on_patch, draw_contours_on_patch, draw_boxes_on_3d_image
from model_server.extensions.chaeo.accessors import write_patch_to_file, MonoPatchStack, PatchStack from model_server.extensions.chaeo.accessors import write_patch_to_file, PatchStack
from base.process import mask_largest_object from base.process import mask_largest_object
...@@ -225,16 +225,15 @@ class RoiSet(object): ...@@ -225,16 +225,15 @@ class RoiSet(object):
projected = self.acc_raw.data.max(axis=-1) projected = self.acc_raw.data.max(axis=-1)
return projected return projected
# TODO: remove, since padding is implicit in PatchStack
# TODO: test case where patch channel is restricted
def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches
if channel: if channel:
patches_df = self.get_patches(white_channel=channel, pad_to=pad_to) patches_df = self.get_patches(white_channel=channel, pad_to=pad_to)
else: else:
patches_df = self.get_patches(pad_to=pad_to) patches_df = self.get_patches(pad_to=pad_to)
patches = list(patches_df['patch']) patches = list(patches_df['patch'])
if channel is not None or self.acc_raw.chroma == 1: return PatchStack(patches)
return MonoPatchStack(patches)
else:
return PatchStack(patches)
def export_annotated_zstack(self, where, prefix='zstack', **kwargs): def export_annotated_zstack(self, where, prefix='zstack', **kwargs):
annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs))
...@@ -340,7 +339,7 @@ class RoiSet(object): ...@@ -340,7 +339,7 @@ class RoiSet(object):
return exported return exported
def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack: def get_patch_masks(self, pad_to: int = 256) -> PatchStack:
patches = [] patches = []
for roi in self: for roi in self:
patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
...@@ -350,7 +349,8 @@ class RoiSet(object): ...@@ -350,7 +349,8 @@ class RoiSet(object):
patch = pad(patch, pad_to) patch = pad(patch, pad_to)
patches.append(patch) patches.append(patch)
return MonoPatchStack(patches) return PatchStack(patches)
def get_patches( def get_patches(
self, self,
......
...@@ -158,6 +158,20 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -158,6 +158,20 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
def test_raw_patches_are_correct_shape(self):
roiset = self._make_roi_set()
patches = roiset.get_raw_patches()
np, h, w, nc, nz = patches.shape
self.assertEqual(np, roiset.count)
self.assertEqual(nc, roiset.acc_raw.chroma)
def test_patch_masks_are_correct_shape(self):
roiset = self._make_roi_set()
patch_masks = roiset.get_patch_masks()
np, h, w, nc, nz = patch_masks.shape
self.assertEqual(np, roiset.count)
self.assertEqual(nc, 1)
class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
...@@ -233,3 +247,5 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -233,3 +247,5 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(result.chroma, self.stack.chroma) self.assertEqual(result.chroma, self.stack.chroma)
self.assertEqual(result.nz, self.stack.nz) self.assertEqual(result.nz, self.stack.nz)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment