Skip to content
Snippets Groups Projects
Commit addec54f authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

By default, return non-expanded patches and masks; allow contour and mask...

By default, return non-expanded patches and masks; allow contour and mask overlays in non-expanded patches
parent 3854679d
No related branches found
No related tags found
No related merge requests found
......@@ -200,7 +200,7 @@ class RoiSet(object):
df['slice'] = df.apply(
lambda r:
np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1) + 1],
np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)],
axis=1
)
df['expanded_slice'] = df.apply(
......@@ -279,14 +279,12 @@ class RoiSet(object):
projected = self.acc_raw.data.max(axis=-1)
return projected
# TODO: rename get_patches_acc
def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches
def get_patches_acc(self, channel=None, **kwargs): # padded, un-annotated 2d patches
if channel:
patches_df = self.get_patches(white_channel=channel, pad_to=pad_to)
patches_df = self.get_patches(white_channel=channel, **kwargs)
else:
patches_df = self.get_patches(pad_to=pad_to)
patches = list(patches_df['patch'])
return PatchStack(patches)
patches_df = self.get_patches(**kwargs)
return PatchStack(list(patches_df.patch))
def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> Path:
annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs))
......@@ -335,8 +333,8 @@ class RoiSet(object):
# do this on a patch basis, i.e. only one object per frame
obmap_patches = object_classification_model.label_patch_stack(
self.get_raw_patches(channel=channel, pad_to=256),
self.get_patch_masks_acc(expanded=True, pad_to=256)
self.get_patches_acc(channel=channel, expaned=False, pad_to=None),
self.get_patch_masks_acc(expanded=False, pad_to=None)
)
om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype)
......@@ -419,13 +417,13 @@ class RoiSet(object):
def get_patches(
self,
rescale_clip: float = 0.0,
pad_to: int = 256,
pad_to: int = None,
make_3d: bool = False,
focus_metric: str = None,
rgb_overlay_channels: list = None,
rgb_overlay_weights: list = [1.0, 1.0, 1.0],
white_channel: int = None,
expanded=True,
expanded=False,
**kwargs
) -> pd.DataFrame:
......@@ -503,6 +501,12 @@ class RoiSet(object):
assert len(patch.shape) == 4
mask = np.zeros(patch3d.shape[0:2], dtype=bool)
if expanded:
mask[roi.relative_slice[0:2]] = roi.binary_mask
else:
mask = roi.binary_mask
if rescale_clip is not None:
patch = rescale(patch, rescale_clip)
......@@ -520,15 +524,15 @@ class RoiSet(object):
if kwargs.get('draw_mask'):
mci = kwargs.get('mask_channel', 0)
mask = np.zeros(patch.shape[0:2], dtype=bool)
mask[roi.relative_slice[0:2]] = roi.binary_mask
# mask = np.zeros(patch.shape[0:2], dtype=bool)
# mask[roi.relative_slice[0:2]] = roi.binary_mask
for zi in range(0, patch.shape[3]):
patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
if kwargs.get('draw_contour'):
mci = kwargs.get('contour_channel', 0)
mask = np.zeros(patch.shape[0:2], dtype=bool)
mask[roi.relative_slice[0:2]] = roi.binary_mask
# mask = np.zeros(patch.shape[0:2], dtype=bool)
# mask[roi.relative_slice[0:2]] = roi.binary_mask
for zi in range(0, patch.shape[3]):
patch[:, :, mci, zi] = draw_contours_on_patch(
......
......@@ -329,8 +329,8 @@ class TestIlastikObjectClassification(unittest.TestCase):
def test_classify_patches(self):
raw_patches = self.roiset.get_raw_patches()
patch_masks = self.roiset.get_patch_masks()
raw_patches = self.roiset.get_patches_acc()
patch_masks = self.roiset.get_patch_masks_acc()
res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks)
self.assertEqual(res_patches.count, self.roiset.count)
res_patches.export_pyxcz(output_path / 'res_patches.tif')
......
......@@ -105,7 +105,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
files = roiset.export_patches(
output_path / 'expanded_2d_patches',
draw_bounding_box=True,
expanded=True
expanded=True,
pad_to=256,
)
df = roiset.get_df()
for f in files:
......@@ -113,9 +114,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
la = int(re.search(r'la([\d]+)', str(f)).group(1))
roi_q = df.loc[df.label == la, :]
self.assertEqual(len(roi_q), 1)
roi = roi_q.iloc[0]
h = int(roi.ebb_y1 - roi.ebb_y0)
w = int(roi.ebb_x1 - roi.ebb_x0)
self.assertEqual((256, 256), acc.hw)
def test_make_tight_2d_patches(self):
......@@ -201,10 +199,11 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_raw_patches_are_correct_shape(self):
roiset = self._make_roi_set()
patches = roiset.get_raw_patches()
patches = roiset.get_patches_acc()
np, h, w, nc, nz = patches.shape
self.assertEqual(np, roiset.count)
self.assertEqual(nc, roiset.acc_raw.chroma)
self.assertEqual(nz, 1)
def test_patch_masks_are_correct_shape(self):
roiset = self._make_roi_set()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment