Commit 0026647e authored by Constantin Pape's avatar Constantin Pape

Fix issues with training stardist2d with multiple channels

parent 0c9a391b
......@@ -90,18 +90,35 @@ def make_train_val_split(train_images, train_labels, validation_fraction):
return x_train, y_train, x_val, y_val
def _flip_with_channel(x, axis):
n_channels = x.shape[-1]
out = np.zeros_like(x)
for chan_id in range(n_channels):
out[..., chan_id] = np.flip(x[..., chan_id], axis=axis)
return out
# TODO add more augmentations and refactor this so it can be used elsewhere
def random_flips_and_rotations(x, y):
assert y.ndim == 2
assert x.ndim in (2, 3)
is_multichannel = x.ndim == 3
# first, rotate randomly
axes = tuple(range(x.ndim))
permute = np.random.permutation(axes)
x, y = x.transpose(permute), y.transpose(permute)
axes = tuple(range(y.ndim))
permute = tuple(np.random.permutation(axes))
y = y.transpose(permute)
if is_multichannel:
# need to leave last axis (= channel axis) as is in the permutation
permute = permute + (2,)
x = x.transpose(permute)
# second, flip randomly
for ax in axes:
if np.random.rand() > .5:
x, y = np.flip(x, axis=ax), np.flip(y, axis=ax)
y = np.flip(y, axis=ax)
x = _flip_with_channel(x, ax) if is_multichannel else np.flip(x, axis=ax)
return x, y
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment