mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
Implement JoinDropout
This commit is contained in:
parent
e7886d49a9
commit
9fb8645168
@ -1029,6 +1029,8 @@ class Conv2dSubsampling(nn.Module):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
|
@ -713,114 +713,96 @@ class GaussProjDrop(torch.nn.Module):
|
||||
x = (x_next * self.rand_scale + x_bypass)
|
||||
return x
|
||||
|
||||
|
||||
class Decorrelate(torch.nn.Module):
|
||||
class JoinDropout(torch.nn.Module):
|
||||
"""
|
||||
This module is something similar to dropout; it is a random transformation that
|
||||
does nothing in eval mode.
|
||||
It is designed specifically to encourage the input data to be decorrelated, i.e.
|
||||
to have a diagonal covariance matrix (not necessarily unity).
|
||||
This module implements something like:
|
||||
y = bypass + dropout(x)
|
||||
but does it in such a way as to encourage x to vary in directions that will tend
|
||||
to make the dimensions of y as decorrelated as possible. We do this
|
||||
by putting lots of dropout in directions in the space in which we
|
||||
don't want x to vary (because it will tend to increase correlations between
|
||||
dimensions in the output y).
|
||||
|
||||
To save time, in training mode we only apply it on randomly selected minibatches.
|
||||
|
||||
Args:
|
||||
num_channels: The number of channels, e.g. 256.
|
||||
apply_prob: The probability with which we apply this each time, in
|
||||
training mode. This is to save time (but of course it
|
||||
will tend to make the effect weaker).
|
||||
dropout_rate: This number determines the scale of the random multiplicative
|
||||
noise, in such a way that the self-correlation and cross-correlation
|
||||
statistics match those dropout with the same `dropout_rate`
|
||||
(assuming we applied the transform, e.g. if apply_prob == 1.0)
|
||||
This number applies when the features are un-correlated.
|
||||
max_dropout_rate: This is an upper limit, for safety, on how aggressive the
|
||||
randomization can be.
|
||||
eps: An epsilon used to prevent division by zero.
|
||||
dropout_rate: This number determines the average dropout probability
|
||||
(it will actually vary across dimensions).
|
||||
eps: An epsilon used to prevent division by zero.
|
||||
beta: A value 0 < beta < 1 that controls decay of covariance stats
|
||||
channel_dim: The dimension of the input corresponding to the channel, e.g.
|
||||
-1, 0, 1, 2.
|
||||
"""
|
||||
def __init__(self,
|
||||
num_channels: int,
|
||||
apply_prob: float = 0.25,
|
||||
apply_prob: float = 0.75,
|
||||
dropout_rate: float = 0.1,
|
||||
eps: float = 1.0e-04,
|
||||
beta: float = 0.95,
|
||||
channel_dim: int = -1):
|
||||
super(Decorrelate, self).__init__()
|
||||
super(JoinDropout, self).__init__()
|
||||
self.apply_prob = apply_prob
|
||||
self.dropout_rate = dropout_rate
|
||||
self.channel_dim = channel_dim
|
||||
self.eps = eps
|
||||
self.beta = beta
|
||||
|
||||
#rand_mat = torch.randn(num_channels, num_channels)
|
||||
#U, _, _ = rand_mat.svd()
|
||||
#self.register_buffer('U', U) # a random orthogonal square matrix. will be a buffer.
|
||||
|
||||
self.register_buffer('T1', torch.eye(num_channels))
|
||||
self.register_buffer('rand_scales', torch.zeros(num_channels))
|
||||
self.register_buffer('nonrand_scales', torch.ones(num_channels))
|
||||
self.register_buffer('dropout_probs', torch.zeros(num_channels))
|
||||
self.register_buffer('scales', torch.ones(num_channels))
|
||||
self.register_buffer('T2', torch.eye(num_channels))
|
||||
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
|
||||
self.step = 0
|
||||
|
||||
|
||||
|
||||
def _update_covar_stats(self, x: Tensor) -> None:
|
||||
def _update_covar_stats(self, y: Tensor) -> None:
|
||||
"""
|
||||
Args:
|
||||
x: Tensor of shape (*, num_channels)
|
||||
y: Tensor of shape (*, num_channels), of output.
|
||||
Updates covariance stats self.cov
|
||||
"""
|
||||
x = x.detach()
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
x = x * (x.shape[0] ** -0.5) # avoid overflow in half precision
|
||||
cov = torch.matmul(x.t(), x)
|
||||
y = y.detach()
|
||||
y = y.reshape(-1, y.shape[-1])
|
||||
y = y * (y.shape[0] ** -0.5) # avoid overflow in half precision
|
||||
cov = torch.matmul(y.t(), y)
|
||||
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
|
||||
self.step += 1
|
||||
|
||||
def _update_transforms(self):
|
||||
|
||||
norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
|
||||
|
||||
U, S, _ = norm_cov.svd()
|
||||
U, S, _ = norm_cov.svd() # because diag of norm_cov is 1.0, S.mean() == 1.0
|
||||
|
||||
if random.random() < 0.1:
|
||||
logging.info(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
|
||||
logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
|
||||
|
||||
dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
|
||||
dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
|
||||
dropout_probs = dropout_probs.clamp(max=0.5)
|
||||
self.dropout_probs[:] = dropout_probs
|
||||
self.scales[:] = 1.0 / (1 - dropout_probs)
|
||||
|
||||
|
||||
# row indexes of U correspond to channels, column indexes correspond to
|
||||
# singular values: cov = U * diag(S) * U.t() where * is matmul.
|
||||
S_eps = S + self.eps
|
||||
S_sqrt = S_eps ** 0.5
|
||||
S_inv_sqrt = (S + self.eps) ** -0.5
|
||||
|
||||
|
||||
# Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
|
||||
# (i) multiply by inv_sqrt_diag which makes the covariance have
|
||||
# a unit diagonal.
|
||||
# (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
|
||||
# (iii) divide by S_sqrt, which makes all dims have unit variance.
|
||||
self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
|
||||
self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
|
||||
|
||||
# Transform T1, which we'll incorporate as torch.matmul(x, self.TT), is:
|
||||
# (i) multiply by S_sqrt, which restors the variance of different dims,
|
||||
# (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
|
||||
# (iii) divide by inv_sqrt_diag which makes the covariance have its original
|
||||
# Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
|
||||
# (i) multiply by U, which un-diagonalizes norm_cov
|
||||
# (ii) divide by inv_sqrt_diag which makes the covariance have its original
|
||||
# diagonal values.
|
||||
self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
|
||||
self.T2[:] = (U.t() / inv_sqrt_diag)
|
||||
|
||||
|
||||
# OK, now get rand_scales, which are values between 0 and self.dropout_rate; it says
|
||||
# how much randomness will be in different eigenvalues of norm_cov.
|
||||
# Basically, we want more randomness in directions with eigenvalues more than one,
|
||||
# and none in those with eigenvalues less than one.
|
||||
rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
|
||||
|
||||
# rand_proportion is viewed as representing a proportion of the covariance, since
|
||||
# the random and nonrandom components will not be correlated.
|
||||
self.rand_scales[:] = rand_proportion.sqrt()
|
||||
self.nonrand_scales[:] = (1.0 - rand_proportion).sqrt()
|
||||
|
||||
|
||||
if True:
|
||||
if random.random() < 0.01:
|
||||
d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
|
||||
device=self.T1.device,
|
||||
dtype=self.T1.dtype)
|
||||
@ -839,75 +821,38 @@ class Decorrelate(torch.nn.Module):
|
||||
diag = cov.diag()
|
||||
inv_sqrt_diag = (diag + self.eps) ** -0.5
|
||||
cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
|
||||
assert torch.all((cov.diag() - 1.0).abs() < 0.1) # TODO: remove
|
||||
return cov, inv_sqrt_diag
|
||||
|
||||
|
||||
def _randperm_like(self, x: Tensor):
|
||||
"""
|
||||
Returns random permutations of the integers [0,1,..x.shape[-1]-1],
|
||||
with the same shape as x. All dimensions of x other than the last dimension
|
||||
will be treated as batch dimensions.
|
||||
|
||||
Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
|
||||
|
||||
For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
|
||||
we normally set channel dims. This is required for some number theoretic stuff.
|
||||
"""
|
||||
n = x.shape[-1]
|
||||
|
||||
assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
|
||||
|
||||
b = x.numel() // n
|
||||
randint = random.randint(0, 1000)
|
||||
perm = torch.randperm(n, device=x.device)
|
||||
# ensure all elements of batch_rand are coprime to n; this will ensure
|
||||
# that multiplying the permutation by batch_rand and taking modulo
|
||||
# n leaves us with permutations.
|
||||
batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
|
||||
batch_rand = batch_rand.unsqueeze(-1)
|
||||
ans = (perm * batch_rand) % n
|
||||
ans = ans.reshape(x.shape)
|
||||
return ans
|
||||
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
|
||||
if not self.training or random.random() > self.apply_prob:
|
||||
return x
|
||||
return bypass + x
|
||||
else:
|
||||
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
|
||||
bypass = bypass.transpose(self.channel_dim, -1)
|
||||
|
||||
x = torch.matmul(x, self.T1.clone())
|
||||
|
||||
mask = (torch.rand_like(x) > self.dropout_probs)
|
||||
x = (x * mask) * self.scales.clone()
|
||||
x = torch.matmul(x, self.T2.clone())
|
||||
|
||||
y = bypass + x
|
||||
self.step += 1
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
self._update_covar_stats(x)
|
||||
if self.step % 50 == 0 or __name__ == "__main__":
|
||||
self._update_transforms()
|
||||
if self.step % 4 == 0 or __name__ == "__main__":
|
||||
self._update_covar_stats(y)
|
||||
if self.step % 40 == 0 or __name__ == "__main__":
|
||||
# note: important that 40 is a multiple of 4
|
||||
self._update_transforms()
|
||||
|
||||
x = torch.matmul(x, self.T1)
|
||||
y = y.transpose(self.channel_dim, -1)
|
||||
return y
|
||||
|
||||
x_bypass = x
|
||||
|
||||
if False:
|
||||
# This block, in effect, multiplies x by a random orthogonal matrix,
|
||||
# giving us random noise.
|
||||
perm = self._randperm_like(x)
|
||||
x = torch.gather(x, -1, perm)
|
||||
# self.U will act like a different matrix for every row of x,
|
||||
# because of the random permutation.
|
||||
x = torch.matmul(x, self.U)
|
||||
x_next = torch.empty_like(x)
|
||||
# scatter_ uses perm in opposite way
|
||||
# from gather, inverting it.
|
||||
x_next.scatter_(-1, perm, x)
|
||||
x = x_next
|
||||
|
||||
mask = (torch.rand_like(x) > 0.5)
|
||||
x = x - (x * mask) * 2
|
||||
|
||||
x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
|
||||
|
||||
x = torch.matmul(x, self.T2)
|
||||
x = x.transpose(self.channel_dim, -1)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@ -1005,8 +950,7 @@ def _test_gauss_proj_drop():
|
||||
m1.eval()
|
||||
m2.eval()
|
||||
|
||||
def _test_decorrelate():
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
def _test_join_dropout():
|
||||
D = 384
|
||||
x = torch.randn(30000, D)
|
||||
|
||||
@ -1014,13 +958,14 @@ def _test_decorrelate():
|
||||
m = torch.randn(D, D)
|
||||
x = torch.matmul(x, m)
|
||||
|
||||
|
||||
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
|
||||
m1 = torch.nn.Dropout(dropout_rate)
|
||||
m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
|
||||
m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
|
||||
bypass = torch.zeros_like(x)
|
||||
for mode in ['train', 'eval']:
|
||||
y1 = m1(x)
|
||||
y2 = m2(x)
|
||||
for _ in range(2):
|
||||
y2 = m2(bypass, x)
|
||||
xmag = (x*x).mean()
|
||||
y1mag = (y1*y1).mean()
|
||||
cross1 = (x*y1).mean()
|
||||
@ -1032,7 +977,10 @@ def _test_decorrelate():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_decorrelate()
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_join_dropout()
|
||||
_test_gauss_proj_drop()
|
||||
_test_activation_balancer_sign()
|
||||
_test_activation_balancer_magnitude()
|
||||
|
@ -29,7 +29,7 @@ from scaling import (
|
||||
ScaledConv1d,
|
||||
ScaledConv2d,
|
||||
ScaledLinear,
|
||||
Decorrelate,
|
||||
JoinDropout,
|
||||
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
@ -198,8 +198,10 @@ class ConformerEncoderLayer(nn.Module):
|
||||
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
|
||||
)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
|
||||
self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
|
||||
self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
|
||||
self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
|
||||
self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
|
||||
|
||||
|
||||
def forward(
|
||||
@ -243,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
alpha = 1.0
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
@ -254,17 +256,13 @@ class ConformerEncoderLayer(nn.Module):
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = src + self.dropout(src_att)
|
||||
src = self.dropout_self_attn(src, src_att)
|
||||
|
||||
# convolution module
|
||||
src = src + self.dropout(self.conv_module(src))
|
||||
src = self.dropout_conv(src, self.conv_module(src))
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
# encourage dimensions of `src` to be un-correlated with each other, this will
|
||||
# help Adam converge better.
|
||||
src = self.decorrelate(src)
|
||||
src = self.dropout_ff(src, self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
@ -1326,6 +1324,9 @@ def _test_random_combine_main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
|
Loading…
x
Reference in New Issue
Block a user