Use decorrelation in conformer layers also

Daniel Povey 2022-06-09 00:05:49 +08:00
parent b9a476c7bb
commit 1669e21c0c
2 changed files with 68 additions and 24 deletions

View File

@@ -746,7 +746,7 @@ class DecorrelateFunction(torch.autograd.Function):
         inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
         norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))
-        loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
+        loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
         if random.random() < 0.01:
             logging.info(f"Decorrelate: loss = {loss}")
         loss.backward()
@@ -758,7 +758,9 @@ class DecorrelateFunction(torch.autograd.Function):
         # `loss ** 0.5` times the magnitude of the original grad.
         x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
         x_grad_old_scale = (x_grad ** 2).sum(dim=1)
-        decorr_loss_scale = ctx.scale
+        decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
         scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
         x_grad_new = x_grad_new * scale.unsqueeze(-1)
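
A note on the two hunks above (not part of the commit): the loss is the sum of squared off-diagonal entries of the normalized covariance divided by the number of channels, so it is zero exactly when the channels are uncorrelated; the backward pass then blends the gradient of that loss into the existing gradient after rescaling it per frame. A minimal sketch of that rescaling, assuming gradients of shape (num_frames, num_channels) and a detached scalar loss as in the code above (names like ctx_scale are stand-ins for the diff's ctx.scale):

    import torch

    num_frames, num_channels = 4, 8
    x_grad = torch.randn(num_frames, num_channels)      # gradient w.r.t. the module's output
    x_grad_new = torch.randn(num_frames, num_channels)  # gradient of the decorrelation loss
    loss = torch.tensor(0.3)                            # stands in for loss.detach()
    ctx_scale = 0.1                                     # stands in for ctx.scale

    decorr_loss_scale = ctx_scale * loss.clamp(min=0.0, max=1.0)
    scale = decorr_loss_scale * (
        (x_grad ** 2).sum(dim=1) / ((x_grad_new ** 2).sum(dim=1) + 1.0e-10)
    ) ** 0.5
    x_grad_new = x_grad_new * scale.unsqueeze(-1)

    # Per frame, ||x_grad_new|| is now roughly decorr_loss_scale * ||x_grad||,
    # so the decorrelation term shrinks as the loss falls and never swamps
    # the task gradient.
    print(x_grad_new.norm(dim=1) / x_grad.norm(dim=1))  # ~0.03 everywhere here
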
@@ -776,27 +778,56 @@ class Decorrelate(torch.nn.Module):
     This module does nothing in the forward pass, but in the backward pass, modifies
     the derivatives in such a way as to encourage the dimensions of its input to become
     decorrelated.
+
+    Args:
+      num_channels: The number of channels, e.g. 256.
+      apply_prob_decay: Controls the probability of applying this module each time,
+                   in training mode; the probability decays as
+                   apply_prob_decay / (apply_prob_decay + step).
+      scale: This number determines the scale of the gradient contribution from
+                   this module, relative to whatever the gradient was before;
+                   this is applied per frame or pixel, by scaling gradients.
+      eps: An epsilon used to prevent division by zero.
+      beta: A value 0 < beta < 1 that controls decay of covariance stats.
+      channel_dim: The dimension of the input corresponding to the channel, e.g.
+                   -1, 0, 1, 2.
     """
     def __init__(self,
                  num_channels: int,
                  scale: float = 0.1,
+                 apply_prob_decay: int = 1000,
                  eps: float = 1.0e-05,
                  beta: float = 0.95,
                  channel_dim: int = -1):
         super(Decorrelate, self).__init__()
         self.scale = scale
+        self.apply_prob_decay = apply_prob_decay
         self.eps = eps
         self.beta = beta
         self.channel_dim = channel_dim
         self.register_buffer('cov', torch.zeros(num_channels, num_channels))
+        # step_buf is a copy of step, included so it will be loaded/saved with
+        # the model.
+        self.register_buffer('step_buf', torch.tensor(0))
         self.step = 0

+    def load_state_dict(self, *args, **kwargs):
+        super(Decorrelate, self).load_state_dict(*args, **kwargs)
+        self.step = self.step_buf.item()
+
     def forward(self, x: Tensor) -> Tensor:
         if not self.training:
             return x
         else:
+            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
+            self.step += 1
+            self.step_buf.fill_(self.step)
+            if random.random() > apply_prob:
+                return x
             with torch.cuda.amp.autocast(enabled=False):
                 ans = DecorrelateFunction.apply(x, self.cov.clone(),
                                                 self.scale, self.eps, self.beta,
@@ -807,7 +838,6 @@ class Decorrelate(torch.nn.Module):
             cov = torch.matmul(x.t(), x)
             with torch.no_grad():
                 self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-            self.step += 1
             return ans  # ans == x.
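
For orientation (again, not part of the diff): the step counter now drives the new apply_prob_decay schedule, so the extra covariance/gradient work fades out as training proceeds. A tiny sketch of the schedule, assuming the default apply_prob_decay=1000 from the hunk above:

    def apply_prob(step: int, apply_prob_decay: int = 1000) -> float:
        # probability of applying Decorrelate on a given training step
        return apply_prob_decay / (step + apply_prob_decay)

    for step in (0, 1000, 4000, 9000, 99000):
        print(step, apply_prob(step))
    # 0 -> 1.0, 1000 -> 0.5, 4000 -> 0.2, 9000 -> 0.1, 99000 -> 0.01
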
@@ -825,9 +855,8 @@ class JoinDropout(torch.nn.Module):
     Args:
       num_channels: The number of channels, e.g. 256.
-      apply_prob: The probability with which we apply this each time, in
-                   training mode.  This is to save time (but of course it
-                   will tend to make the effect weaker).
+      apply_prob_decay: Controls the probability of applying this in training mode;
+                   the probability decays as apply_prob_decay / (apply_prob_decay + step).
       dropout_rate: This number determines the average dropout probability
                    (it will actually vary across dimensions).
       eps: An epsilon used to prevent division by zero.
@@ -925,7 +954,8 @@ class JoinDropout(torch.nn.Module):
     def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
-        if not self.training or random.random() > self.apply_prob:
+        apply_prob = self.apply_prob
+        if not self.training or random.random() > apply_prob:
             return bypass + x
         else:
             x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
@@ -1049,6 +1079,28 @@ def _test_gauss_proj_drop():
     m1.eval()
     m2.eval()

+
+def _test_decorrelate():
+    D = 384
+    x = torch.randn(30000, D)
+    # give it a non-unit covariance.
+    m = torch.randn(D, D) * (D ** -0.5)
+    _, S, _ = m.svd()
+    print("M eigs = ", S[::10])
+    x = torch.matmul(x, m)
+
+    # check that class Decorrelate does not crash when running..
+    decorrelate = Decorrelate(D)
+    x.requires_grad = True
+    y = decorrelate(x)
+    y.sum().backward()
+
+    decorrelate2 = Decorrelate(D)
+    decorrelate2.load_state_dict(decorrelate.state_dict())
+    assert decorrelate2.step == decorrelate.step
+
+
 def _test_join_dropout():
     D = 384
     x = torch.randn(30000, D)
@@ -1060,13 +1112,6 @@ def _test_join_dropout():
     x = torch.matmul(x, m)

-    if True:
-        # check that class Decorrelate does not crash when running..
-        decorrelate = Decorrelate(D)
-        x.requires_grad = True
-        y = decorrelate(x)
-        y.sum().backward()
-
     for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
         m1 = torch.nn.Dropout(dropout_rate)
         m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
@@ -1089,6 +1134,7 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_decorrelate()
     _test_join_dropout()
     _test_gauss_proj_drop()
     _test_activation_balancer_sign()

View File

@@ -198,10 +198,8 @@ class ConformerEncoderLayer(nn.Module):
             channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
         )

-        self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-        self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-        self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-        self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+        self.dropout = nn.Dropout(dropout)
+        self.decorrelate = Decorrelate(d_model, scale=0.02)

     def forward(
@@ -245,7 +243,7 @@ class ConformerEncoderLayer(nn.Module):
             alpha = 1.0

         # macaron style feed forward module
-        src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
+        src = src + self.dropout(self.feed_forward_macaron(src))

         # multi-headed self-attention module
         src_att = self.self_attn(
@@ -256,16 +254,18 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = self.dropout_self_attn(src, src_att)
+        src = src + self.dropout(src_att)

         # convolution module
-        src = self.dropout_conv(src, self.conv_module(src))
+        src = src + self.dropout(self.conv_module(src))

         # feed forward module
-        src = self.dropout_ff(src, self.feed_forward(src))
+        src = src + self.dropout(self.feed_forward(src))

         src = self.norm_final(self.balancer(src))
+        src = self.decorrelate(src)

         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig
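
Putting the ConformerEncoderLayer hunks together, a sketch (not the real class; TinyLayer, ff1 and ff2 are hypothetical stand-ins for the macaron / attention / conv / feed-forward sub-modules): each sub-module goes back to a plain residual-plus-dropout form, and Decorrelate is applied once at the end of the layer with scale=0.02, as added by this commit.

    import torch
    import torch.nn as nn
    # Decorrelate is the class modified in the first file of this commit;
    # import it from that module.

    class TinyLayer(nn.Module):
        def __init__(self, d_model: int, dropout: float = 0.1):
            super().__init__()
            self.ff1 = nn.Linear(d_model, d_model)  # stand-in for feed_forward_macaron
            self.ff2 = nn.Linear(d_model, d_model)  # stand-in for attention/conv/feed_forward
            self.dropout = nn.Dropout(dropout)      # plain dropout again, replacing JoinDropout
            self.decorrelate = Decorrelate(d_model, scale=0.02)

        def forward(self, src: torch.Tensor) -> torch.Tensor:
            src = src + self.dropout(self.ff1(src))
            src = src + self.dropout(self.ff2(src))
            # no-op in the forward pass; decorrelates gradients in the backward pass
            return self.decorrelate(src)
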
@@ -1032,7 +1032,6 @@ class Conv2dSubsampling(nn.Module):
         # itself has learned scale, so the extra degree of freedom is not
         # needed.
         self.out_norm = BasicNorm(out_channels, learn_eps=False)
-        self.decorrelate = Decorrelate(out_channels)
         # constrain median of output to be close to zero.
         self.out_balancer = ActivationBalancer(
             channel_dim=-1, min_positive=0.45, max_positive=0.55
@@ -1057,7 +1056,6 @@ class Conv2dSubsampling(nn.Module):
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         x = self.out_norm(x)
-        x = self.decorrelate(x)
         x = self.out_balancer(x)
         return x