mirror of https://github.com/k2-fsa/icefall.git

Implement JoinDropout

parent e7886d49a9
commit 9fb8645168
@@ -1029,6 +1029,8 @@ class Conv2dSubsampling(nn.Module):
 
 
 if __name__ == "__main__":
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
@@ -713,114 +713,96 @@ class GaussProjDrop(torch.nn.Module):
             x = (x_next * self.rand_scale + x_bypass)
             return x
 
 
-class Decorrelate(torch.nn.Module):
+class JoinDropout(torch.nn.Module):
     """
-    This module is something similar to dropout; it is a random transformation that
-    does nothing in eval mode.
-    It is designed specifically to encourage the input data to be decorrelated, i.e.
-    to have a diagonal covariance matrix (not necessarily unity).
-
-    To save time, in training mode we only apply it on randomly selected minibatches.
+    This module implements something like:
+        y = bypass + dropout(x)
+    but does it in such a way as to encourage x to vary in directions that will tend
+    to make the dimensions of y as decorrelated as possible.  We do this
+    by putting lots of dropout in directions in the space in which we
+    don't want x to vary (because it will tend to increase correlations between
+    dimensions in the output y).
 
     Args:
       num_channels:  The number of channels, e.g. 256.
       apply_prob:  The probability with which we apply this each time, in
                  training mode.  This is to save time (but of course it
                  will tend to make the effect weaker).
-      dropout_rate:  This number determines the scale of the random multiplicative
-                 noise, in such a way that the self-correlation and cross-correlation
-                 statistics match those of dropout with the same `dropout_rate`
-                 (assuming we applied the transform, e.g. if apply_prob == 1.0).
-                 This number applies when the features are un-correlated.
-      max_dropout_rate:  This is an upper limit, for safety, on how aggressive the
-                 randomization can be.
-      eps:  An epsilon used to prevent division by zero.
+      dropout_rate:  This number determines the average dropout probability
+                 (it will actually vary across dimensions).
+      eps:  An epsilon used to prevent division by zero.
       beta:  A value 0 < beta < 1 that controls decay of covariance stats
+      channel_dim:  The dimension of the input corresponding to the channel, e.g.
+                 -1, 0, 1, 2.
     """
     def __init__(self,
                  num_channels: int,
-                 apply_prob: float = 0.25,
+                 apply_prob: float = 0.75,
                  dropout_rate: float = 0.1,
                  eps: float = 1.0e-04,
                  beta: float = 0.95,
                  channel_dim: int = -1):
-        super(Decorrelate, self).__init__()
+        super(JoinDropout, self).__init__()
         self.apply_prob = apply_prob
         self.dropout_rate = dropout_rate
         self.channel_dim = channel_dim
         self.eps = eps
         self.beta = beta
 
-        # rand_mat = torch.randn(num_channels, num_channels)
-        # U, _, _ = rand_mat.svd()
-        # self.register_buffer('U', U)  # a random orthogonal square matrix.  will be a buffer.
-
         self.register_buffer('T1', torch.eye(num_channels))
-        self.register_buffer('rand_scales', torch.zeros(num_channels))
-        self.register_buffer('nonrand_scales', torch.ones(num_channels))
+        self.register_buffer('dropout_probs', torch.zeros(num_channels))
+        self.register_buffer('scales', torch.ones(num_channels))
         self.register_buffer('T2', torch.eye(num_channels))
         self.register_buffer('cov', torch.zeros(num_channels, num_channels))
         self.step = 0
 
 
-    def _update_covar_stats(self, x: Tensor) -> None:
+    def _update_covar_stats(self, y: Tensor) -> None:
         """
         Args:
-          x: Tensor of shape (*, num_channels)
+          y: Tensor of shape (*, num_channels), of output.
         Updates covariance stats self.cov
         """
-        x = x.detach()
-        x = x.reshape(-1, x.shape[-1])
-        x = x * (x.shape[0] ** -0.5)  # avoid overflow in half precision
-        cov = torch.matmul(x.t(), x)
+        y = y.detach()
+        y = y.reshape(-1, y.shape[-1])
+        y = y * (y.shape[0] ** -0.5)  # avoid overflow in half precision
+        cov = torch.matmul(y.t(), y)
         self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-        self.step += 1
 
     def _update_transforms(self):
 
         norm_cov, inv_sqrt_diag = self._normalize_covar(self.cov)
 
-        U, S, _ = norm_cov.svd()
+        U, S, _ = norm_cov.svd()  # because diag of norm_cov is 1.0, S.mean() == 1.0
 
         if random.random() < 0.1:
-            logging.info(f"Decorrelate: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
+            logging.info(f"JoinDropout: max,min eig of normalized cov is: {S.max().item():.2e},{S.min().item():.2e}")
 
+        dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
+        dropout_probs = dropout_probs * (self.dropout_rate / dropout_probs.mean())
+        dropout_probs = dropout_probs.clamp(max=0.5)
+        self.dropout_probs[:] = dropout_probs
+        self.scales[:] = 1.0 / (1 - dropout_probs)
 
         # row indexes of U correspond to channels, column indexes correspond to
         # singular values: cov = U * diag(S) * U.t() where * is matmul.
-        S_eps = S + self.eps
-        S_sqrt = S_eps ** 0.5
-        S_inv_sqrt = (S + self.eps) ** -0.5
 
         # Transform T1, which we'll incorporate as torch.matmul(x, self.T1), is:
         #  (i) multiply by inv_sqrt_diag which makes the covariance have
         #      a unit diagonal.
         #  (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        #  (iii) divide by S_sqrt, which makes all dims have unit variance.
-        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U / S_sqrt)
+        self.T1[:] = (inv_sqrt_diag.unsqueeze(-1) * U)
 
-        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
-        #  (i) multiply by S_sqrt, which restores the variance of different dims,
-        #  (ii) multiply by U, which diagonalizes norm_cov (uncorrelated channels)
-        #  (iii) divide by inv_sqrt_diag which makes the covariance have its original
+        # Transform T2, which we'll incorporate as torch.matmul(x, self.T2), is:
+        #  (i) multiply by U, which un-diagonalizes norm_cov
+        #  (ii) divide by inv_sqrt_diag which makes the covariance have its original
         #      diagonal values.
-        self.T2[:] = (S_sqrt.unsqueeze(-1) * U.t() / inv_sqrt_diag)
+        self.T2[:] = (U.t() / inv_sqrt_diag)
 
 
-        # OK, now get rand_scales, which are values between 0 and self.dropout_rate; it says
-        # how much randomness will be in different eigenvalues of norm_cov.
-        # Basically, we want more randomness in directions with eigenvalues more than one,
-        # and none in those with eigenvalues less than one.
-        rand_proportion = (S - 1.0).clamp(min=0.0, max=1.0) * self.dropout_rate
-
-        # rand_proportion is viewed as representing a proportion of the covariance, since
-        # the random and nonrandom components will not be correlated.
-        self.rand_scales[:] = rand_proportion.sqrt()
-        self.nonrand_scales[:] = (1.0 - rand_proportion).sqrt()
-
-
-        if True:
+        if random.random() < 0.01:
             d = torch.matmul(self.T1, self.T2) - torch.eye(self.T1.shape[0],
                                                            device=self.T1.device,
                                                            dtype=self.T1.dtype)
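The transform construction in `_update_transforms` can be hard to follow inside the diff, so here is a small standalone sketch (not part of the commit; the function name, the random test covariance and the printed check are made up for illustration) of how the per-direction dropout probabilities and the T1/T2 change-of-basis matrices are derived from an accumulated covariance of the output:

    import torch

    def join_dropout_transforms(cov: torch.Tensor, dropout_rate: float = 0.1,
                                eps: float = 1.0e-04):
        # Normalize the covariance to have a unit diagonal; its eigenvalues then
        # average to 1, so values above 1 mark directions that are "too correlated".
        diag = cov.diag()
        inv_sqrt_diag = (diag + eps) ** -0.5
        norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
        U, S, _ = norm_cov.svd()

        # More dropout in eigen-directions with larger eigenvalues
        # (assumes at least one eigenvalue exceeds 0.99**2, as the original code does).
        dropout_probs = (S.sqrt() - 0.99).clamp(min=0)
        dropout_probs = dropout_probs * (dropout_rate / dropout_probs.mean())
        dropout_probs = dropout_probs.clamp(max=0.5)
        scales = 1.0 / (1 - dropout_probs)  # compensate so the expectation is preserved

        T1 = inv_sqrt_diag.unsqueeze(-1) * U  # into the decorrelated basis
        T2 = U.t() / inv_sqrt_diag            # back to the original basis
        return T1, dropout_probs, scales, T2

    # Illustrative check that T1 @ T2 is (numerically) the identity:
    cov = torch.randn(8, 100); cov = cov @ cov.t() / 100
    T1, probs, scales, T2 = join_dropout_transforms(cov)
    print((T1 @ T2 - torch.eye(8)).abs().max())

Because U is orthogonal, T1 @ T2 reduces to the identity, which is exactly what the `d = torch.matmul(self.T1, self.T2) - torch.eye(...)` check in the new code verifies from time to time.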
@@ -839,75 +821,38 @@ class Decorrelate(torch.nn.Module):
         diag = cov.diag()
         inv_sqrt_diag = (diag + self.eps) ** -0.5
         cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        assert torch.all((cov.diag() - 1.0).abs() < 0.1)  # TODO: remove
         return cov, inv_sqrt_diag
 
 
-    def _randperm_like(self, x: Tensor):
-        """
-        Returns random permutations of the integers [0,1,..x.shape[-1]-1],
-        with the same shape as x.  All dimensions of x other than the last dimension
-        will be treated as batch dimensions.
-
-        Torch's randperm does not support a batch dimension, so we pseudo-randomly simulate it.
-
-        For now, requires x.shape[-1] to be either a power of 2 or 3 times a power of 2, as
-        we normally set channel dims.  This is required for some number theoretic stuff.
-        """
-        n = x.shape[-1]
-
-        assert n & (n-1) == 0 or (n//3 & (n//3 - 1)) == 0
-
-        b = x.numel() // n
-        randint = random.randint(0, 1000)
-        perm = torch.randperm(n, device=x.device)
-        # ensure all elements of batch_rand are coprime to n; this will ensure
-        # that multiplying the permutation by batch_rand and taking modulo
-        # n leaves us with permutations.
-        batch_rand = torch.arange(b, device=x.device) * (randint * 6) + 1
-        batch_rand = batch_rand.unsqueeze(-1)
-        ans = (perm * batch_rand) % n
-        ans = ans.reshape(x.shape)
-        return ans
-
-
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
         if not self.training or random.random() > self.apply_prob:
-            return x
+            return bypass + x
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
+            bypass = bypass.transpose(self.channel_dim, -1)
+
+            x = torch.matmul(x, self.T1.clone())
+
+            mask = (torch.rand_like(x) > self.dropout_probs)
+            x = (x * mask) * self.scales.clone()
+            x = torch.matmul(x, self.T2.clone())
+
+            y = bypass + x
+            self.step += 1
             with torch.cuda.amp.autocast(enabled=False):
-                self._update_covar_stats(x)
-                if self.step % 50 == 0 or __name__ == "__main__":
-                    self._update_transforms()
-
-            x = torch.matmul(x, self.T1)
-
-            x_bypass = x
-
-            if False:
-                # This block, in effect, multiplies x by a random orthogonal matrix,
-                # giving us random noise.
-                perm = self._randperm_like(x)
-                x = torch.gather(x, -1, perm)
-                # self.U will act like a different matrix for every row of x,
-                # because of the random permutation.
-                x = torch.matmul(x, self.U)
-                x_next = torch.empty_like(x)
-                # scatter_ uses perm in opposite way
-                # from gather, inverting it.
-                x_next.scatter_(-1, perm, x)
-                x = x_next
-
-            mask = (torch.rand_like(x) > 0.5)
-            x = x - (x * mask) * 2
-
-            x = (x * self.rand_scales) + (x_bypass * self.nonrand_scales)
-
-            x = torch.matmul(x, self.T2)
-            x = x.transpose(self.channel_dim, -1)
-            return x
+                if self.step % 4 == 0 or __name__ == "__main__":
+                    self._update_covar_stats(y)
+                if self.step % 40 == 0 or __name__ == "__main__":
+                    # note: important that 40 is a multiple of 4
+                    self._update_transforms()
+
+            y = y.transpose(self.channel_dim, -1)
+            return y
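As a quick illustration of the new interface (a minimal sketch, not from the commit; the tensor shapes and hyperparameter values are made up), the module is called as `m(bypass, x)`; in training mode it returns `bypass` plus a randomized version of `x`, and in eval mode it reduces exactly to `bypass + x`:

    import torch
    from scaling import JoinDropout  # the modified scaling.py above

    d_model = 256
    m = JoinDropout(d_model, apply_prob=0.5, dropout_rate=0.1)

    bypass = torch.randn(10, 20, d_model)  # e.g. the residual stream
    delta = torch.randn(10, 20, d_model)   # e.g. a sub-module's output

    m.train()
    y = m(bypass, delta)        # bypass + (a randomized version of delta)

    m.eval()
    y_eval = m(bypass, delta)   # exactly bypass + delta in eval mode
    assert torch.equal(y_eval, bypass + delta)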
@@ -1005,8 +950,7 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()
 
-def _test_decorrelate():
-    logging.getLogger().setLevel(logging.INFO)
+def _test_join_dropout():
     D = 384
     x = torch.randn(30000, D)
 
@@ -1014,13 +958,14 @@ def _test_decorrelate():
     m = torch.randn(D, D)
     x = torch.matmul(x, m)
 
     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
-        m2 = Decorrelate(D, apply_prob=1.0, dropout_rate=dropout_rate)
+        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
+        bypass = torch.zeros_like(x)
         for mode in ['train', 'eval']:
             y1 = m1(x)
-            y2 = m2(x)
+            for _ in range(2):
+                y2 = m2(bypass, x)
             xmag = (x*x).mean()
             y1mag = (y1*y1).mean()
             cross1 = (x*y1).mean()
@@ -1032,7 +977,10 @@ def _test_decorrelate():
 
 
 if __name__ == "__main__":
-    _test_decorrelate()
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
+    _test_join_dropout()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()
     _test_activation_balancer_magnitude()
@@ -29,7 +29,7 @@ from scaling import (
     ScaledConv1d,
     ScaledConv2d,
     ScaledLinear,
-    Decorrelate,
+    JoinDropout,
 )
 from torch import Tensor, nn
@@ -198,8 +198,10 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )
 
-        self.dropout = torch.nn.Dropout(dropout)
-        self.decorrelate = Decorrelate(d_model, apply_prob=0.25, dropout_rate=0.05)
+        self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
 
 
     def forward(
@@ -243,7 +245,7 @@ class ConformerEncoderLayer(nn.Module):
         alpha = 1.0
 
         # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
 
         # multi-headed self-attention module
         src_att = self.self_attn(
@@ -254,17 +256,13 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = self.dropout_self_attn(src, src_att)
 
         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        src = self.dropout_conv(src, self.conv_module(src))
 
         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
+        src = self.dropout_ff(src, self.feed_forward(src))
 
-        # encourage dimensions of `src` to be un-correlated with each other, this will
-        # help Adam converge better.
-        src = self.decorrelate(src)
-
         src = self.norm_final(self.balancer(src))
 
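The encoder-layer hunks above replace the shared `src = src + self.dropout(...)` residual pattern with one JoinDropout per branch. Below is a minimal sketch of that wiring for a single feed-forward branch (illustrative only; the class name, layer sizes and module names here are not from the commit):

    import torch
    import torch.nn as nn
    from scaling import JoinDropout  # the modified scaling.py above

    class TinyResidualBranch(nn.Module):
        """One residual branch wired the new way:
        join(bypass, sublayer(bypass)) instead of bypass + dropout(sublayer(bypass))."""
        def __init__(self, d_model: int, dropout: float = 0.1):
            super().__init__()
            self.feed_forward = nn.Sequential(
                nn.Linear(d_model, 4 * d_model),
                nn.ReLU(),
                nn.Linear(4 * d_model, d_model),
            )
            self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            # Plays the role of `src = src + self.dropout(self.feed_forward(src))`
            # in the old code, but the join also pushes the output dimensions
            # towards being decorrelated.
            return self.dropout_ff(src, self.feed_forward(src))

    branch = TinyResidualBranch(d_model=128)
    out = branch(torch.randn(50, 4, 128))  # (time, batch, d_model)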
@@ -1326,6 +1324,9 @@ def _test_random_combine_main():
 
 
 if __name__ == "__main__":
+    logging.getLogger().setLevel(logging.INFO)
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5