Implement JoinDropout

Daniel Povey 2022-06-08 16:11:48 +08:00
parent e7886d49a9
commit 9fb8645168
3 changed files with 80 additions and 129 deletions

View File

@@ -1029,6 +1029,8 @@ class Conv2dSubsampling(nn.Module):
if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5

View File

@@ -713,114 +713,96 @@ class GaussProjDrop(torch.nn.Module):
x = (x_next * self.rand_scale + x_bypass)
return x
class Decorrelate(torch.nn.Module):
class JoinDropout(torch.nn.Module):
"""
This module is something similar to dropout; it is a random transformation that
does nothing in eval mode.
It is designed specifically to encourage the input data to be decorrelated, i.e.
to have a diagonal covariance matrix (not necessarily unity).
This module implements something like:
y = bypass + dropout(x)
but does it in such a way as to encourage x to vary in directions that will tend
to make the dimensions of y as decorrelated as possible. We do this
by putting lots of dropout in the directions of the space in which we
don't want x to vary (because variation in those directions tends to increase
correlations between dimensions of the output y).
To save time, in training mode we only apply it on randomly selected minibatches.
Args:
num_channels: The number of channels, e.g. 256.
apply_prob: The probability with which we apply this each time, in
training mode. This is to save time (but of course it
will tend to make the effect weaker).
dropout_rate: This number determines the scale of the random multiplicative
noise, in such a way that the self-correlation and cross-correlation
statistics match those of dropout with the same `dropout_rate`
(assuming we applied the transform, e.g. if apply_prob == 1.0).
This number applies when the features are un-correlated.
max_dropout_rate: This is an upper limit, for safety, on how aggressive the
randomization can be.
eps: An epsilon used to prevent division by zero.
dropout_rate: This number determines the average dropout probability
(it will actually vary across dimensions).
eps: An epsilon used to prevent division by zero.
beta: A value 0 < beta < 1 that controls the decay of the covariance stats.
channel_dim: The dimension of the input corresponding to the channel, e.g.
-1, 0, 1, 2.
"""
def __init__(self,
num_channels: int,
apply_prob: float = 0.25,
apply_prob: float = 0.75,
dropout_rate: float = 0.1,
eps: float = 1.0e-04,
beta: float = 0.95,
channel_dim: int = -1):
super(Decorrelate, self).__init__()
super(JoinDropout, self).__init__()
self.apply_prob = apply_prob
self.dropout_rate = dropout_rate
self.channel_dim = channel_dim
self.eps = eps
self.beta = beta
#rand_mat = torch.randn(num_channels, num_channels)
#U, _, _ = rand_mat.svd()
#self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
self.register_buffer('T1', torch.eye(num_channels))
self.register_buffer('rand_scales', torch.zeros(num_channels))
self.register_buffer('nonrand_scales', torch.ones(num_channels))
self.register_buffer('dropout_probs', torch.zeros(num_channels))
self.register_buffer('scales', torch.ones(num_channels))
self.register_buffer('T2', torch.eye(num_channels))
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
self.step = 0
def _update_covar_stats(self, x: Tensor) -> None:
def _update_covar_stats(self, y: Tensor) -> None:
"""
Args:
x: Tensor of shape (*, num_channels)
y: Tensor of shape (*, num_channels): the output of this module.
Updates covariance stats self.cov
"""
x = x.detach()
x = x.reshape(-1, x.shape[-1])
x = x * (x.shape[0] ** -0.5) # avoid overflow in half precision
cov = torch.matmul(x.t(), x)
y = y.detach()
y = y.reshape(-1, y.shape[-1])
y = y * (y.shape[0] ** -0.5) # avoid overflow in half precision
cov = torch.matmul(y.t(), y)
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
self.step += 1
def _update_transforms(self):
norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
U, S, _ = norm_cov.svd()
U, S, _ = norm_cov.svd() # because diag of norm_cov is 1.0, S.mean() == 1.0
if random.random() < 0.1:
logging.info(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
dropout_probs = dropout_probs.clamp(max=0.5)
self.dropout_probs[:] = dropout_probs
self.scales[:] = 1.0 / (1 - dropout_probs)
# row indexes of U correspond to channels, column indexes correspond to
# singular values: norm_cov = U * diag(S) * U.t(), where * is matmul.
S_eps = S + self.eps
S_sqrt = S_eps ** 0.5
S_inv_sqrt = (S + self.eps) ** -0.5
# Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
# (i) multiply by inv_sqrt_diag which makes the covariance have
# a unit diagonal.
# (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
# (iii) divide by S_sqrt, which makes all dims have unit variance.
self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
# Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
# (i) multiply by S_sqrt, which restores the variance of the different dims,
# (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
# (iii) divide by inv_sqrt_diag which makes the covariance have its original
# Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
# (i) multiply by U, which un-diagonalizes norm_cov
# (ii) divide by inv_sqrt_diag which makes the covariance have its original
# diagonal values.
self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
self.T2[:] = (U.t() / inv_sqrt_diag)
# OK, now get rand_scales, which are values between 0 and self.dropout_rate; it says
# how much randomness will be in different eigenvalues of norm_cov.
# Basically, we want more randomness in directions with eigenvalues more than one,
# and none in those with eigenvalues less than one.
rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
# rand_proportion is viewed as representing a proportion of the covariance, since
# the random and nonrandom components will not be correlated.
self.rand_scales[:] = rand_proportion.sqrt()
self.nonrand_scales[:] = (1.0 - rand_proportion).sqrt()
if True:
if random.random() < 0.01:
d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
device=self.T1.device,
dtype=self.T1.dtype)
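The comments above describe T1 and T2 only in words. The following standalone sketch (hypothetical sizes, not part of the commit) checks the two properties the module relies on: x @ T1 has an approximately diagonal covariance whose diagonal is the eigenvalue vector S, and T1 @ T2 is the identity, so the round trip is exact when no dropout is applied:

    import torch

    torch.manual_seed(0)
    num_channels, eps = 8, 1.0e-04
    # correlated data: random Gaussian passed through a random mixing matrix
    x = torch.randn(10000, num_channels) @ torch.randn(num_channels, num_channels)

    cov = (x.t() @ x) / x.shape[0]
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))  # unit diagonal
    U, S, _ = norm_cov.svd()

    T1 = inv_sqrt_diag.unsqueeze(-1) * U   # channel basis -> eigenbasis of norm_cov
    T2 = U.t() / inv_sqrt_diag             # eigenbasis -> channel basis

    assert torch.allclose(T1 @ T2, torch.eye(num_channels), atol=1e-3)

    z = x @ T1                              # decorrelated view of x
    cov_z = (z.t() @ z) / z.shape[0]
    off_diag = cov_z - torch.diag(cov_z.diag())
    assert off_diag.abs().max() < 1e-2      # covariance of z is (nearly) diagonal
    assert torch.allclose(cov_z.diag(), S, rtol=1e-2)
    # directions with S > 1 are the over-correlated ones that receive the most dropout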
@@ -839,75 +821,38 @@ class Decorrelate(torch.nn.Module):
diag = cov.diag()
inv_sqrt_diag = (diag + self.eps) ** -0.5
cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
assert torch.all((cov.diag() - 1.0).abs() < 0.1) # TODO: remove
return cov, inv_sqrt_diag
def _randperm_like(self, x: Tensor):
"""
Returns random permutations of the integers [0,1,..x.shape[-1]-1],
with the same shape as x. All dimensions of x other than the last dimension
will be treated as batch dimensions.
Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2,
which is how we normally set channel dims. This is required so that the multipliers
used below (all of the form 6*k + 1, hence coprime to 2 and 3) are coprime to n.
"""
n = x.shape[-1]
assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
b = x.numel() // n
randint = random.randint(0, 1000)
perm = torch.randperm(n, device=x.device)
# ensure all elements of batch_rand are coprime to n; this will ensure
# that multiplying the permutation by batch_rand and taking modulo
# n leaves us with permutations.
batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
batch_rand = batch_rand.unsqueeze(-1)
ans = (perm * batch_rand) % n
ans = ans.reshape(x.shape)
return ans
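For reference, a tiny sketch (values chosen for illustration) of the number-theoretic fact this helper relies on: multiplying a permutation of 0..n-1 by an integer coprime to n, modulo n, yields another permutation:

    import torch

    n = 12                    # 3 * 2**2, the kind of channel dim the assert above allows
    perm = torch.randperm(n)
    k = 6 * 5 + 1             # 31; anything of the form 6*m + 1 is coprime to 2 and 3,
                              # hence coprime to n whenever n = 2**a or n = 3 * 2**a
    new_perm = (perm * k) % n
    assert sorted(new_perm.tolist()) == list(range(n))   # still a permutation of 0..n-1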
def forward(self, x: Tensor) -> Tensor:
def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
if not self.training or random.random() > self.apply_prob:
return x
return bypass + x
else:
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
bypass = bypass.transpose(self.channel_dim, -1)
x = torch.matmul(x, self.T1.clone())
mask = (torch.rand_like(x) > self.dropout_probs)
x = (x * mask) * self.scales.clone()
x = torch.matmul(x, self.T2.clone())
y = bypass + x
self.step += 1
with torch.cuda.amp.autocast(enabled=False):
self._update_covar_stats(x)
if self.step % 50 == 0 or __name__ == "__main__":
self._update_transforms()
if self.step % 4 == 0 or __name__ == "__main__":
self._update_covar_stats(y)
if self.step % 40 == 0 or __name__ == "__main__":
# note: important that 40 is a multiple of 4
self._update_transforms()
x = torch.matmul(x, self.T1)
y = y.transpose(self.channel_dim, -1)
return y
x_bypass = x
if False:
# This block, in effect, multiplies x by a random orthogonal matrix,
# giving us random noise.
perm = self._randperm_like(x)
x = torch.gather(x, -1, perm)
# self.U will act like a different matrix for every row of x,
# because of the random permutation.
x = torch.matmul(x, self.U)
x_next = torch.empty_like(x)
# scatter_ uses perm in opposite way
# from gather, inverting it.
x_next.scatter_(-1, perm, x)
x = x_next
mask = (torch.rand_like(x) > 0.5)
x = x - (x * mask) * 2
x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
x = torch.matmul(x, self.T2)
x = x.transpose(self.channel_dim, -1)
return x
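The masking step in the new forward() (mask = rand > dropout_probs, followed by multiplication with self.scales) is ordinary inverted dropout, just with a different keep probability per eigen-direction; a quick sketch (illustrative probabilities, not from the commit) of why scales is set to 1 / (1 - dropout_probs):

    import torch

    torch.manual_seed(0)
    dropout_probs = torch.tensor([0.0, 0.1, 0.3])    # per-channel drop probabilities
    scales = 1.0 / (1.0 - dropout_probs)

    x = torch.ones(200000, 3)
    mask = (torch.rand_like(x) > dropout_probs)       # keep with probability 1 - p
    y = (x * mask) * scales

    print(y.mean(dim=0))   # all three values are close to 1.0: rescaling keeps E[y] == x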
@@ -1005,8 +950,7 @@ def _test_gauss_proj_drop():
m1.eval()
m2.eval()
def _test_decorrelate():
logging.getLogger().setLevel(logging.INFO)
def _test_join_dropout():
D = 384
x = torch.randn(30000, D)
@@ -1014,13 +958,14 @@ def _test_decorrelate():
m = torch.randn(D, D)
x = torch.matmul(x, m)
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
m1 = torch.nn.Dropout(dropout_rate)
m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
bypass = torch.zeros_like(x)
for mode in ['train', 'eval']:
y1 = m1(x)
y2 = m2(x)
for _ in range(2):
y2 = m2(bypass, x)
xmag = (x*x).mean()
y1mag = (y1*y1).mean()
cross1 = (x*y1).mean()
@@ -1032,7 +977,10 @@ def _test_decorrelate():
if __name__ == "__main__":
_test_decorrelate()
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_join_dropout()
_test_gauss_proj_drop()
_test_activation_balancer_sign()
_test_activation_balancer_magnitude()

View File

@@ -29,7 +29,7 @@ from scaling import (
ScaledConv1d,
ScaledConv2d,
ScaledLinear,
Decorrelate,
JoinDropout,
)
from torch import Tensor, nn
@@ -198,8 +198,10 @@ class ConformerEncoderLayer(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout = torch.nn.Dropout(dropout)
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
def forward(
@@ -243,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
alpha = 1.0
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
# multi-headed self-attention module
src_att = self.self_attn(
@@ -254,17 +256,13 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
src = self.dropout_self_attn(src, src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
src = self.dropout_conv(src, self.conv_module(src))
# feed forward module
src = src + self.dropout(self.feed_forward(src))
# encourage dimensions of `src` to be un-correlated with each other; this will
# help Adam converge better.
src = self.decorrelate(src)
src = self.dropout_ff(src, self.feed_forward(src))
src = self.norm_final(self.balancer(src))
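In eval mode each of the four new JoinDropout modules reduces to exactly the residual add it replaces, so the layer's eval-time behaviour is unchanged; a quick sketch (hypothetical shapes) of that equivalence for a single sub-module:

    import torch
    from scaling import JoinDropout

    d_model = 16
    jd = JoinDropout(d_model, apply_prob=0.5, dropout_rate=0.1)
    drop = torch.nn.Dropout(0.1)
    jd.eval(); drop.eval()

    src = torch.randn(4, 3, d_model)       # (seq_len, batch, d_model), as in the Conformer
    sub_out = torch.randn(4, 3, d_model)   # e.g. self.feed_forward_macaron(src)

    # old wiring: src + self.dropout(sub_out); new wiring: self.dropout_ff_macaron(src, sub_out)
    assert torch.equal(jd(src, sub_out), src + drop(sub_out))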
@@ -1326,6 +1324,9 @@ def _test_random_combine_main():
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5