Use decorrelation in conformer layers also

Daniel Povey 2022-06-09 00:05:49 +08:00
parent b9a476c7bb
commit 1669e21c0c
2 changed files with 68 additions and 24 deletions

File 1 of 2:

@@ -746,7 +746,7 @@ class DecorrelateFunction(torch.autograd.Function):
inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
if random.random() < 0.01:
logging.info(f"Decorrelate: loss = {loss}")
loss.backward()
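For reference, a minimal standalone sketch of the quantity this hunk computes (names mirror the diff; `x` stands in for the activations that the saved covariance stats summarize): the covariance is rescaled to have unit diagonal, and the loss is the sum of squared off-diagonal entries divided by the number of channels. The changed line drops the trailing `- 1`, presumably so the loss stays non-negative for the clamping introduced in the next hunk.

import torch

def decorrelation_loss(x: torch.Tensor, eps: float = 1.0e-05) -> torch.Tensor:
    # x: (num_frames, num_channels)
    num_channels = x.shape[1]
    cov = torch.matmul(x.t(), x)                       # unnormalized covariance stats
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    # rescale so that the diagonal of norm_cov is (close to) 1
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
    # sum of squared off-diagonal entries, divided by the number of channels
    return ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels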
@@ -758,7 +758,9 @@ class DecorrelateFunction(torch.autograd.Function):
# `loss ** 0.5` times the magnitude of the original grad.
x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
x_grad_old_scale = (x_grad ** 2).sum(dim=1)
decorr_loss_scale = ctx.scale
decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
x_grad_new = x_grad_new * scale.unsqueeze(-1)
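The rescaling in this hunk can be summarized as: per frame, the decorrelation gradient `x_grad_new` is renormalized to the magnitude of the original gradient `x_grad`, times a factor that with this change also includes the clamped decorrelation loss, so the extra term fades out once the channels are already decorrelated. A sketch, assuming both gradients have shape (num_frames, num_channels):

import torch

def rescale_decorr_grad(x_grad: torch.Tensor,
                        x_grad_new: torch.Tensor,
                        base_scale: float,
                        loss: torch.Tensor) -> torch.Tensor:
    # per-frame squared magnitudes, shape: (num_frames,)
    x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
    x_grad_old_scale = (x_grad ** 2).sum(dim=1)
    # the change: weight the contribution by the decorrelation loss, clamped to [0, 1]
    decorr_loss_scale = base_scale * loss.detach().clamp(min=0.0, max=1.0)
    scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
    return x_grad_new * scale.unsqueeze(-1)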
@@ -776,27 +778,56 @@ class Decorrelate(torch.nn.Module):
This module does nothing in the forward pass, but in the backward pass, modifies
the derivatives in such a way as to encourage the dimensions of its input to become
decorrelated.
Args:
num_channels: The number of channels, e.g. 256.
apply_prob_decay: Controls how the probability of applying this module each time, in
training mode, decays with the step: apply_prob = apply_prob_decay / (apply_prob_decay + step).
scale: Determines the magnitude of the gradient contribution from this module,
relative to the magnitude of whatever the gradient was before (additionally
weighted by the decorrelation loss, clamped to [0, 1]); it is applied
per frame or pixel, by rescaling gradients.
eps: An epsilon used to prevent division by zero.
beta: A value 0 < beta < 1 that controls decay of covariance stats
channel_dim: The dimension of the input corresponding to the channel, e.g.
-1, 0, 1, 2.
"""
def __init__(self,
num_channels: int,
scale: float = 0.1,
apply_prob_decay: int = 1000,
eps: float = 1.0e-05,
beta: float = 0.95,
channel_dim: int = -1):
super(Decorrelate, self).__init__()
self.scale = scale
self.apply_prob_decay = apply_prob_decay
self.eps = eps
self.beta = beta
self.channel_dim = channel_dim
self.register_buffer('cov', torch.zeros(num_channels, num_channels))
# step_buf is a copy of step, included so it will be loaded/saved with
# the model.
self.register_buffer('step_buf', torch.tensor(0))
self.step = 0
def load_state_dict(self, *args, **kwargs):
super(Decorrelate, self).load_state_dict(*args, **kwargs)
self.step = self.step_buf.item()
def forward(self, x: Tensor) -> Tensor:
if not self.training:
return x
else:
apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
self.step += 1
self.step_buf.fill_(self.step)
if random.random() > apply_prob:
return x
with torch.cuda.amp.autocast(enabled=False):
ans = DecorrelateFunction.apply(x, self.cov.clone(),
self.scale, self.eps, self.beta,
@@ -807,7 +838,6 @@ class Decorrelate(torch.nn.Module):
cov = torch.matmul(x.t(), x)
with torch.no_grad():
self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
self.step += 1
return ans # ans == x.
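Taken together, a hedged usage sketch of the class as defined above: the forward pass is the identity, a step-dependent probability decides whether the backward hook is attached on a given call, and `cov` is an exponential moving average of per-batch covariance stats with decay `beta`. Shapes are illustrative, mirroring the test further down.

import torch

decorrelate = Decorrelate(num_channels=384, scale=0.1, apply_prob_decay=1000)
decorrelate.train()                 # the module is a no-op in eval mode

x = torch.randn(3000, 384, requires_grad=True)
y = decorrelate(x)                  # forward pass: y equals x numerically
y.sum().backward()                  # backward pass gains a decorrelation term
                                    # on roughly an apply_prob fraction of calls

# the step counter is mirrored into step_buf so it survives save/load:
state = decorrelate.state_dict()
decorrelate2 = Decorrelate(num_channels=384)
decorrelate2.load_state_dict(state)  # restores decorrelate2.step as well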
@@ -825,9 +855,8 @@ class JoinDropout(torch.nn.Module):
Args:
num_channels: The number of channels, e.g. 256.
apply_prob: The probability with which we apply this each time, in
training mode. This is to save time (but of course it
will tend to make the effect weaker).
apply_prob_decay: Controls how the probability of applying this module each time, in
training mode, decays with the step: apply_prob = apply_prob_decay / (apply_prob_decay + step).
dropout_rate: This number determines the average dropout probability
(it will actually vary across dimensions).
eps: An epsilon used to prevent division by zero.
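Both docstrings now describe the same decaying application probability as used in `Decorrelate.forward` above; a small sketch of the schedule (using 1000, `Decorrelate`'s default, purely for illustration):

def apply_prob(step: int, apply_prob_decay: int = 1000) -> float:
    # probability of applying the module at a given training step
    return apply_prob_decay / (apply_prob_decay + step)

# apply_prob(0) == 1.0, apply_prob(1000) == 0.5, apply_prob(9000) == 0.1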
@@ -925,7 +954,8 @@ class JoinDropout(torch.nn.Module):
def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
if not self.training or random.random() > self.apply_prob:
apply_prob = self.apply_prob
if not self.training or random.random() > apply_prob:
return bypass + x
else:
x = x.transpose(self.channel_dim, -1) # (..., num_channels)
@@ -1049,6 +1079,28 @@ def _test_gauss_proj_drop():
m1.eval()
m2.eval()
def _test_decorrelate():
D = 384
x = torch.randn(30000, D)
# give it a non-unit covariance.
m = torch.randn(D, D) * (D ** -0.5)
_, S, _ = m.svd()
print("M eigs = ", S[::10])
x = torch.matmul(x, m)
# check that class Decorrelate does not crash when running..
decorrelate = Decorrelate(D)
x.requires_grad = True
y = decorrelate(x)
y.sum().backward()
decorrelate2 = Decorrelate(D)
decorrelate2.load_state_dict(decorrelate.state_dict())
assert decorrelate2.step == decorrelate.step
def _test_join_dropout():
D = 384
x = torch.randn(30000, D)
@@ -1060,13 +1112,6 @@ def _test_join_dropout():
x = torch.matmul(x, m)
if True:
# check that class Decorrelate does not crash when running..
decorrelate = Decorrelate(D)
x.requires_grad = True
y = decorrelate(x)
y.sum().backward()
for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
m1 = torch.nn.Dropout(dropout_rate)
m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
@@ -1089,6 +1134,7 @@ if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_decorrelate()
_test_join_dropout()
_test_gauss_proj_drop()
_test_activation_balancer_sign()

File 2 of 2:

@@ -198,10 +198,8 @@ class ConformerEncoderLayer(nn.Module):
channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
)
self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
self.dropout = nn.Dropout(dropout)
self.decorrelate = Decorrelate(d_model, scale=0.02)
def forward(
@@ -245,7 +243,7 @@ class ConformerEncoderLayer(nn.Module):
alpha = 1.0
# macaron style feed forward module
src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
src = src + self.dropout(self.feed_forward_macaron(src))
# multi-headed self-attention module
src_att = self.self_attn(
@@ -256,16 +254,18 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = self.dropout_self_attn(src, src_att)
src = src + self.dropout(src_att)
# convolution module
src = self.dropout_conv(src, self.conv_module(src))
src = src + self.dropout(self.conv_module(src))
# feed forward module
src = self.dropout_ff(src, self.feed_forward(src))
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
src = self.decorrelate(src)
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
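Condensing the changes to this layer: the per-branch `JoinDropout` joins are replaced by plain residual adds through a shared `nn.Dropout`, and a single `Decorrelate` with a small `scale=0.02` is applied once per layer after the final norm. A hedged sketch of the resulting structure, with stand-in submodules and the attention call, `norm_final`, `balancer` and `alpha` blending left out; `Decorrelate` is assumed importable from the companion file above.

import torch
import torch.nn as nn

class CondensedConformerLayer(nn.Module):
    # sketch only: the real feed-forward/conv submodules are replaced by stand-ins
    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.feed_forward = nn.Linear(d_model, d_model)           # stand-in
        self.feed_forward_macaron = nn.Linear(d_model, d_model)   # stand-in
        self.conv_module = nn.Identity()                          # stand-in
        self.dropout = nn.Dropout(dropout)
        self.decorrelate = Decorrelate(d_model, scale=0.02)       # added by this commit

    def forward(self, src: torch.Tensor, src_att: torch.Tensor) -> torch.Tensor:
        src = src + self.dropout(self.feed_forward_macaron(src))  # macaron feed-forward
        src = src + self.dropout(src_att)                         # self-attention output
        src = src + self.dropout(self.conv_module(src))           # convolution module
        src = src + self.dropout(self.feed_forward(src))          # feed-forward
        return self.decorrelate(src)   # identity in forward, decorrelates gradients in backward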
@@ -1032,7 +1032,6 @@ class Conv2dSubsampling(nn.Module):
# itself has learned scale, so the extra degree of freedom is not
# needed.
self.out_norm = BasicNorm(out_channels, learn_eps=False)
self.decorrelate = Decorrelate(out_channels)
# constrain median of output to be close to zero.
self.out_balancer = ActivationBalancer(
channel_dim=-1, min_positive=0.45, max_positive=0.55
@@ -1057,7 +1056,6 @@ class Conv2dSubsampling(nn.Module):
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
x = self.out_norm(x)
x = self.decorrelate(x)
x = self.out_balancer(x)
return x