Implement JoinDropout

Daniel Povey 2022-06-08 16:11:48 +08:00
parent e7886d49a9
commit 9fb8645168
3 changed files with 80 additions and 129 deletions

View File

@@ -1029,6 +1029,8 @@ class Conv2dSubsampling(nn.Module):
 
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5

View File

@@ -713,114 +713,96 @@ class GaussProjDrop(torch.nn.Module):
         x = (x_next * self.rand_scale + x_bypass)
         return x
 
-class Decorrelate(torch.nn.Module):
+class JoinDropout(torch.nn.Module):
     """
-    This module is something similar to dropout; it is a random transformation that
-    does nothing in eval mode.
-    It is designed specifically to encourage the input data to be decorrelated, i.e.
-    to have a diagonal covariance matrix (not necessarily unity).
+    This module implements something like:
+          y = bypass + dropout(x)
+    but does it in such a way as to encourage x to vary in directions that will tend
+    to make the dimensions of y as decorrelated as possible.  We do this
+    by putting lots of dropout in directions in the space in which we
+    don't want x to vary (because it will tend to increase correlations between
+    dimensions in the output y).
 
     To save time, in training mode we only apply it on randomly selected minibatches.
 
    Args:
      num_channels:  The number of channels, e.g. 256.
      apply_prob:  The probability with which we apply this each time, in
                   training mode.  This is to save time (but of course it
                   will tend to make the effect weaker).
-     dropout_rate:  This number determines the scale of the random multiplicative
-                  noise, in such a way that the self-correlation and cross-correlation
-                  statistics match those dropout with the same `dropout_rate`
-                  (assuming we applied the transform, e.g. if apply_prob == 1.0)
-                  This number applies when the features are un-correlated.
-     max_dropout_rate:  This is an upper limit, for safety, on how aggressive the
-                  randomization can be.
-     eps:  An epsilon used to prevent division by zero.
+     dropout_rate:  This number determines the average dropout probability
+                  (it will actually vary across dimensions).
+     eps:  An epsilon used to prevent division by zero.
      beta:  A value 0 < beta < 1 that controls decay of covariance stats
-     channel_dim:  The dimension of the input corresponding to the channel, e.g.
-                  -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
-                apply_prob: float = 0.25,
+                apply_prob: float = 0.75,
                 dropout_rate: float = 0.1,
                 eps: float = 1.0e-04,
                 beta: float = 0.95,
                 channel_dim: int = -1):
-        super(Decorrelate, self).__init__()
+        super(JoinDropout, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
         self.channel_dim = channel_dim
         self.eps = eps
         self.beta = beta
-        #rand_mat = torch.randn(num_channels, num_channels)
-        #U, _, _ = rand_mat.svd()
-        #self.register_buffer('U', U)  # a random orthogonal square matrix.  will be a buffer.
         self.register_buffer('T1', torch.eye(num_channels))
-        self.register_buffer('rand_scales', torch.zeros(num_channels))
-        self.register_buffer('nonrand_scales', torch.ones(num_channels))
+        self.register_buffer('dropout_probs', torch.zeros(num_channels))
+        self.register_buffer('scales', torch.ones(num_channels))
         self.register_buffer('T2', torch.eye(num_channels))
         self.register_buffer('cov', torch.zeros(num_channels, num_channels))
         self.step = 0
 
-    def _update_covar_stats(self, x: Tensor) -> None:
+    def _update_covar_stats(self, y: Tensor) -> None:
         """
         Args:
-           x: Tensor of shape (*, num_channels)
+           y: Tensor of shape (*, num_channels), of output.
         Updates covariance stats self.cov
         """
-        x = x.detach()
-        x = x.reshape(-1, x.shape[-1])
-        x = x * (x.shape[0] ** -0.5)  # avoid overflow in half precision
-        cov = torch.matmul(x.t(), x)
+        y = y.detach()
+        y = y.reshape(-1, y.shape[-1])
+        y = y * (y.shape[0] ** -0.5)  # avoid overflow in half precision
+        cov = torch.matmul(y.t(), y)
         self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-        self.step += 1
 
     def _update_transforms(self):
         norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
-        U, S, _ = norm_cov.svd()
+        U, S, _ = norm_cov.svd()  # because diag of norm_cov is 1.0, S.mean() == 1.0
         if random.random() < 0.1:
-            logging.info(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+            logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+        dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
+        dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
+        dropout_probs = dropout_probs.clamp(max=0.5)
+        self.dropout_probs[:] = dropout_probs
+        self.scales[:] = 1.0 / (1 - dropout_probs)
         # row indexes of U correspond to channels, column indexes correspond to
         # singular values: cov = U * diag(S) * U.t() where * is matmul.
-        S_eps = S + self.eps
-        S_sqrt = S_eps ** 0.5
-        S_inv_sqrt = (S + self.eps) ** -0.5
 
         # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
         # (i) multiply by inv_sqrt_diag which makes the covariance have
         #     a unit diagonal.
         # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        # (iii) divide by S_sqrt, which makes all dims have unit variance.
-        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
+        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
 
-        # Transform T1, which we'll incorporate as torch.matmul(x, self.TT), is:
-        # (i) multiply by S_sqrt, which restors the variance of different dims,
-        # (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        # (iii) divide by inv_sqrt_diag which makes the covariance have its original
+        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
+        # (i) multiply by U, which un-diagonalizes norm_cov
+        # (ii) divide by inv_sqrt_diag which makes the covariance have its original
         #     diagonal values.
-        self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
+        self.T2[:] = (U.t() / inv_sqrt_diag)
 
-        # OK, now get rand_scales, which are values between 0 and self.dropout_rate; it says
-        # how much randomness will be in different eigenvalues of norm_cov.
-        # Basically, we want more randomness in directions with eigenvalues more than one,
-        # and none in those with eigenvalues less than one.
-        rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
-        # rand_proportion is viewed as representing a proportion of the covariance, since
-        # the random and nonrandom components will not be correlated.
-        self.rand_scales[:] = rand_proportion.sqrt()
-        self.nonrand_scales[:] = (1.0 - rand_proportion).sqrt()
-
-        if True:
+        if random.random() < 0.01:
             d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
                                                            device=self.T1.device,
                                                            dtype=self.T1.dtype)
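
For illustration (the eigenvalue vector and the printed values below are made up, not from this commit), the mapping in _update_transforms() from eigenvalues of the normalized covariance to per-direction dropout probabilities can be run on its own:

    import torch

    S = torch.tensor([2.5, 1.2, 1.0, 0.7, 0.2])  # hypothetical eigenvalues of norm_cov; their mean is close to 1.0
    dropout_rate = 0.1

    dropout_probs = (S.sqrt() - 0.99).clamp(min=0)                         # only directions with eigenvalue above ~1 get dropout
    dropout_probs = dropout_probs * (dropout_rate / dropout_probs.mean())  # rescale so the average probability is dropout_rate
    dropout_probs = dropout_probs.clamp(max=0.5)                           # safety cap
    scales = 1.0 / (1 - dropout_probs)                                     # undo the expected shrinkage from masking

    print(dropout_probs)  # largest for the dominant directions, zero for the weak ones
    print(scales)

So directions in which the output already has large variance receive the most dropout, while directions with eigenvalues below one are left alone.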
@@ -839,75 +821,38 @@ class Decorrelate(torch.nn.Module):
         diag = cov.diag()
         inv_sqrt_diag = (diag + self.eps) ** -0.5
         cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        assert torch.all((cov.diag() - 1.0).abs() < 0.1)  # TODO: remove
         return cov, inv_sqrt_diag
 
-    def _randperm_like(self, x: Tensor):
-        """
-        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
-        with the same shape as x.  All dimensions of x other than the last dimension
-        will be treated as batch dimensions.
-
-        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
-        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
-        we normally set channel dims.  This is required for some number theoretic stuff.
-        """
-        n = x.shape[-1]
-        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
-        b = x.numel() // n
-        randint = random.randint(0, 1000)
-        perm = torch.randperm(n, device=x.device)
-        # ensure all elements of batch_rand are coprime to n; this will ensure
-        # that multiplying the permutation by batch_rand and taking modulo
-        # n leaves us with permutations.
-        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
-        batch_rand = batch_rand.unsqueeze(-1)
-        ans = (perm * batch_rand) % n
-        ans = ans.reshape(x.shape)
-        return ans
-
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
         if not self.training or random.random() > self.apply_prob:
-            return x
+            return bypass + x
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            bypass = bypass.transpose(self.channel_dim, -1)
+            x = torch.matmul(x, self.T1.clone())
+            mask = (torch.rand_like(x) > self.dropout_probs)
+            x = (x * mask) * self.scales.clone()
+            x = torch.matmul(x, self.T2.clone())
+            y = bypass + x
+            self.step += 1
             with torch.cuda.amp.autocast(enabled=False):
-                self._update_covar_stats(x)
-                if self.step % 50 == 0 or __name__ == "__main__":
-                    self._update_transforms()
-
-            x = torch.matmul(x, self.T1)
-
-            x_bypass = x
-            if False:
-                # This block, in effect, multiplies x by a random orthogonal matrix,
-                # giving us random noise.
-                perm = self._randperm_like(x)
-                x = torch.gather(x, -1, perm)
-                # self.U will act like a different matrix for every row of x,
-                # because of the random permutation.
-                x = torch.matmul(x, self.U)
-                x_next = torch.empty_like(x)
-                # scatter_ uses perm in opposite way
-                # from gather, inverting it.
-                x_next.scatter_(-1, perm, x)
-                x = x_next
-            mask = (torch.rand_like(x) > 0.5)
-            x = x - (x * mask) * 2
-            x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
-            x = torch.matmul(x, self.T2)
-            x = x.transpose(self.channel_dim, -1)
-            return x
+                if self.step % 4 == 0 or __name__ == "__main__":
+                    self._update_covar_stats(y)
+                if self.step % 40 == 0 or __name__ == "__main__":
+                    # note: important that 40 is a multiple of 4
+                    self._update_transforms()
+            y = y.transpose(self.channel_dim, -1)
+            return y
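
Read as a whole, the training-mode path of the new forward() boils down to the following minimal sketch; it assumes the buffers T1, T2, dropout_probs and scales have already been filled in by _update_transforms(), and it omits apply_prob, the channel_dim transposes and the covariance-stat updates:

    import torch
    from torch import Tensor

    def join_dropout_sketch(bypass: Tensor, x: Tensor,
                            T1: Tensor, T2: Tensor,
                            dropout_probs: Tensor, scales: Tensor) -> Tensor:
        # Sketch only; assumes the channel dimension is last.
        x = torch.matmul(x, T1)                      # map into the eigen-basis of the diagonal-normalized covariance of y
        mask = (torch.rand_like(x) > dropout_probs)  # heavier dropout in high-eigenvalue directions
        x = (x * mask) * scales                      # rescale so each direction keeps its expected magnitude
        x = torch.matmul(x, T2)                      # map back to the original basis
        return bypass + x                            # the "join": residual bypass plus the dropped-out branch

In eval mode, or on minibatches where the apply_prob test fails, the real module simply returns bypass + x.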
@@ -1005,8 +950,7 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()
 
-def _test_decorrelate():
-    logging.getLogger().setLevel(logging.INFO)
+def _test_join_dropout():
     D = 384
     x = torch.randn(30000, D)
@@ -1014,13 +958,14 @@ def _test_decorrelate():
     m = torch.randn(D, D)
     x = torch.matmul(x, m)
 
     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
+        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
+        bypass = torch.zeros_like(x)
         for mode in ['train', 'eval']:
             y1 = m1(x)
-            y2 = m2(x)
+            for _ in range(2):
+                y2 = m2(bypass, x)
             xmag = (x*x).mean()
             y1mag = (y1*y1).mean()
             cross1 = (x*y1).mean()
@@ -1032,7 +977,10 @@ def _test_decorrelate():
 if __name__ == "__main__":
-    _test_decorrelate()
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_join_dropout()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()

View File

@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    Decorrelate,
+    JoinDropout,
 )
 from torch import Tensor, nn
@@ -198,8 +198,10 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
 
-        self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
+        self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
 
     def forward(
@@ -243,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
             alpha = 1.0
 
         # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
 
         # multi-headed self-attention module
         src_att = self.self_attn(
@@ -254,17 +256,13 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = self.dropout_self_attn(src, src_att)
 
         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        src = self.dropout_conv(src, self.conv_module(src))
 
         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
-
-        # encourage dimensions of `src` to be un-correlated with each other, this will
-        # help Adam converge better.
-        src = self.decorrelate(src)
+        src = self.dropout_ff(src, self.feed_forward(src))
 
         src = self.norm_final(self.balancer(src))
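
In effect, every residual connection in ConformerEncoderLayer switches from adding an independently dropped-out branch to handing both the bypass and the branch to a JoinDropout module, which applies the directional dropout defined in scaling.py. Schematically, using the feed-forward branch (self.dropout was the old plain torch.nn.Dropout):

    # before this commit:
    src = src + self.dropout(self.feed_forward(src))

    # after this commit:
    src = self.dropout_ff(src, self.feed_forward(src))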
@@ -1326,6 +1324,9 @@ def _test_random_combine_main():
 
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5