Mirror of https://github.com/k2-fsa/icefall.git
Use decorrelation in conformer layers also
commit 1669e21c0c
parent b9a476c7bb
@@ -746,7 +746,7 @@ class DecorrelateFunction(torch.autograd.Function):
        inv_sqrt_diag = (cov.diag() + ctx.eps) ** -0.5
        norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))

-       loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels - 1
+       loss = ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels
        if random.random() < 0.01:
            logging.info(f"Decorrelate: loss = {loss}")
        loss.backward()
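The quantity penalized here is the sum of squared off-diagonal entries of the diagonally-normalized (correlation-like) covariance, divided by the number of channels. A standalone sketch of that computation, separate from the diff and using illustrative names (the real DecorrelateFunction works on accumulated covariance stats):

import torch

def decorrelation_loss(x: torch.Tensor, eps: float = 1.0e-05) -> torch.Tensor:
    # x: (num_frames, num_channels); illustrative stand-in, not the diff's code.
    num_frames, num_channels = x.shape
    cov = torch.matmul(x.t(), x) / num_frames                        # channel covariance
    inv_sqrt_diag = (cov.diag() + eps) ** -0.5
    norm_cov = cov * (inv_sqrt_diag * inv_sqrt_diag.unsqueeze(-1))   # unit diagonal
    # zero the diagonal and penalize what remains
    return ((norm_cov - norm_cov.diag().diag()) ** 2).sum() / num_channels

x = torch.randn(2000, 64)
print(decorrelation_loss(x))                        # small: dims nearly uncorrelated
print(decorrelation_loss(x @ torch.randn(64, 64)))  # larger: dims now correlated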
@@ -758,7 +758,9 @@ class DecorrelateFunction(torch.autograd.Function):
        # `loss ** 0.5` times the magnitude of the original grad.
        x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
        x_grad_old_scale = (x_grad ** 2).sum(dim=1)
-       decorr_loss_scale = ctx.scale
+
+       decorr_loss_scale = ctx.scale * loss.detach().clamp(min=0.0, max=1.0)
+
        scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
        x_grad_new = x_grad_new * scale.unsqueeze(-1)

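The effect of the rescaling is that, per frame, the decorrelation gradient's norm becomes decorr_loss_scale times the norm of the incoming gradient. A self-contained numerical sketch, separate from the diff, with an arbitrary stand-in value for the clamped loss:

import torch

x_grad = torch.randn(8, 64)        # incoming gradient, (frames, channels)
x_grad_new = torch.randn(8, 64)    # gradient of the decorrelation loss
decorr_loss_scale = 0.1 * 0.5      # stand-in for ctx.scale * clamp(loss, 0, 1)

x_grad_new_scale = (x_grad_new ** 2).sum(dim=1)
x_grad_old_scale = (x_grad ** 2).sum(dim=1)
scale = decorr_loss_scale * (x_grad_old_scale / (x_grad_new_scale + 1.0e-10)) ** 0.5
x_grad_new = x_grad_new * scale.unsqueeze(-1)

# ratio of per-frame norms: rescaled decorrelation grad vs. incoming grad
print(x_grad_new.norm(dim=1) / x_grad.norm(dim=1))   # == decorr_loss_scale for every frame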
@@ -776,27 +778,56 @@ class Decorrelate(torch.nn.Module):
    This module does nothing in the forward pass, but in the backward pass, modifies
    the derivatives in such a way as to encourage the dimensions of its input to become
    decorrelated.

    Args:
      num_channels: The number of channels, e.g. 256.
      apply_prob_decay: The probability with which we apply this each time, in
           training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
      scale: This number determines the scale of the gradient contribution from
           this module, relative to whatever the gradient was before;
           this is applied per frame or pixel, by scaling gradients.
      eps: An epsilon used to prevent division by zero.
      beta: A value 0 < beta < 1 that controls decay of covariance stats.
      channel_dim: The dimension of the input corresponding to the channel, e.g.
           -1, 0, 1, 2.
    """
    def __init__(self,
                 num_channels: int,
                 scale: float = 0.1,
                 apply_prob_decay: int = 1000,
                 eps: float = 1.0e-05,
                 beta: float = 0.95,
                 channel_dim: int = -1):
        super(Decorrelate, self).__init__()
        self.scale = scale
        self.apply_prob_decay = apply_prob_decay
        self.eps = eps
        self.beta = beta
        self.channel_dim = channel_dim

        self.register_buffer('cov', torch.zeros(num_channels, num_channels))
        # step_buf is a copy of step, included so it will be loaded/saved with
        # the model.
        self.register_buffer('step_buf', torch.tensor(0))
        self.step = 0

    def load_state_dict(self, *args, **kwargs):
        super(Decorrelate, self).load_state_dict(*args, **kwargs)
        self.step = self.step_buf.item()

    def forward(self, x: Tensor) -> Tensor:
        if not self.training:
            return x
        else:
            apply_prob = self.apply_prob_decay / (self.step + self.apply_prob_decay)
            self.step += 1
            self.step_buf.fill_(self.step)
            if random.random() > apply_prob:
                return x
            with torch.cuda.amp.autocast(enabled=False):
                ans = DecorrelateFunction.apply(x, self.cov.clone(),
                                                self.scale, self.eps, self.beta,
@@ -807,7 +838,6 @@ class Decorrelate(torch.nn.Module):
                cov = torch.matmul(x.t(), x)
                with torch.no_grad():
                    self.cov.mul_(self.beta).add_(cov, alpha=(1-self.beta))
-               self.step += 1

                return ans  # ans == x.
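As the updated class shows, Decorrelate is an identity in the forward direction; in training it is applied with a probability that decays with the step count, and its effect enters only through the gradient. A rough sketch of the decay schedule and the call pattern (the usage lines are hypothetical and assume Decorrelate is importable from the file being patched):

# Decay of the apply probability, apply_prob_decay / (step + apply_prob_decay):
apply_prob_decay = 1000
for step in [0, 1000, 4000, 9000]:
    print(step, apply_prob_decay / (step + apply_prob_decay))  # 1.0, 0.5, 0.2, 0.1

# Hypothetical usage inside a model (forward output equals the input; the
# decorrelation term shows up only in x.grad during training):
# self.decorrelate = Decorrelate(num_channels=256, scale=0.02)
# x = self.decorrelate(x)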
@@ -825,9 +855,8 @@ class JoinDropout(torch.nn.Module):

    Args:
      num_channels: The number of channels, e.g. 256.
-     apply_prob: The probability with which we apply this each time, in
-          training mode. This is to save time (but of course it
-          will tend to make the effect weaker).
+     apply_prob_decay: The probability with which we apply this each time, in
+          training mode, will decay as apply_prob_decay/(apply_prob_decay + step).
      dropout_rate: This number determines the average dropout probability
           (it will actually vary across dimensions).
      eps: An epsilon used to prevent division by zero.
@@ -925,7 +954,8 @@ class JoinDropout(torch.nn.Module):


    def forward(self, bypass: Tensor, x: Tensor) -> Tensor:
-       if not self.training or random.random() > self.apply_prob:
+       apply_prob = self.apply_prob
+       if not self.training or random.random() > apply_prob:
            return bypass + x
        else:
            x = x.transpose(self.channel_dim, -1)  # (..., num_channels)
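Unlike a plain nn.Dropout applied to one branch, JoinDropout receives the residual ("bypass") tensor and the branch output together and returns their combination. A call-pattern sketch only, with the constructor arguments mirrored from the test further down in this diff:

import torch

d_model, dropout_rate = 256, 0.1
bypass = torch.randn(10, d_model)   # residual branch
branch = torch.randn(10, d_model)   # submodule output

# ordinary residual connection with dropout on the branch
out_plain = bypass + torch.nn.Dropout(dropout_rate)(branch)

# JoinDropout sees both tensors at once (hypothetical usage; the class itself
# is defined in the file being patched):
# join = JoinDropout(d_model, apply_prob=1.0, dropout_rate=dropout_rate)
# out_join = join(bypass, branch)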
@@ -1049,6 +1079,28 @@ def _test_gauss_proj_drop():
    m1.eval()
    m2.eval()

+def _test_decorrelate():
+    D = 384
+    x = torch.randn(30000, D)
+    # give it a non-unit covariance.
+    m = torch.randn(D, D) * (D ** -0.5)
+    _, S, _ = m.svd()
+    print("M eigs = ", S[::10])
+    x = torch.matmul(x, m)
+
+
+    # check that class Decorrelate does not crash when running..
+    decorrelate = Decorrelate(D)
+    x.requires_grad = True
+    y = decorrelate(x)
+    y.sum().backward()
+
+    decorrelate2 = Decorrelate(D)
+    decorrelate2.load_state_dict(decorrelate.state_dict())
+    assert decorrelate2.step == decorrelate.step
+
+
 def _test_join_dropout():
    D = 384
    x = torch.randn(30000, D)
@@ -1060,13 +1112,6 @@ def _test_join_dropout():
    x = torch.matmul(x, m)


-   if True:
-       # check that class Decorrelate does not crash when running..
-       decorrelate = Decorrelate(D)
-       x.requires_grad = True
-       y = decorrelate(x)
-       y.sum().backward()
-
    for dropout_rate in [0.2, 0.1, 0.01, 0.05]:
        m1 = torch.nn.Dropout(dropout_rate)
        m2 = JoinDropout(D, apply_prob=1.0, dropout_rate=dropout_rate)
@@ -1089,6 +1134,7 @@ if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
+   _test_decorrelate()
    _test_join_dropout()
    _test_gauss_proj_drop()
    _test_activation_balancer_sign()
@@ -198,10 +198,8 @@ class ConformerEncoderLayer(nn.Module):
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )

-       self.dropout_ff_macaron = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-       self.dropout_conv = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-       self.dropout_self_attn = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
-       self.dropout_ff = JoinDropout(d_model, apply_prob=0.5, dropout_rate=dropout)
+       self.dropout = nn.Dropout(dropout)
+       self.decorrelate = Decorrelate(d_model, scale=0.02)


    def forward(
@@ -245,7 +243,7 @@ class ConformerEncoderLayer(nn.Module):
            alpha = 1.0

        # macaron style feed forward module
-       src = self.dropout_ff_macaron(src, self.feed_forward_macaron(src))
+       src = src + self.dropout(self.feed_forward_macaron(src))

        # multi-headed self-attention module
        src_att = self.self_attn(
@@ -256,16 +254,18 @@ class ConformerEncoderLayer(nn.Module):
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
-       src = self.dropout_self_attn(src, src_att)
+       src = src + self.dropout(src_att)

        # convolution module
-       src = self.dropout_conv(src, self.conv_module(src))
+       src = src + self.dropout(self.conv_module(src))

        # feed forward module
-       src = self.dropout_ff(src, self.feed_forward(src))
+       src = src + self.dropout(self.feed_forward(src))

        src = self.norm_final(self.balancer(src))

+       src = self.decorrelate(src)
+
        if alpha != 1.0:
            src = alpha * src + (1 - alpha) * src_orig

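Net effect on the encoder layer: each submodule goes back to a plain residual-plus-dropout connection, and a single Decorrelate hook is applied after the final norm. A pseudocode-level condensation of the updated forward path (the real method also handles positional embeddings, masks and the warmup alpha):

def encoder_layer_forward(self, src):
    # condensed sketch, not the actual signature
    src = src + self.dropout(self.feed_forward_macaron(src))
    src = src + self.dropout(self.self_attn(src, src, src)[0])   # attention call simplified
    src = src + self.dropout(self.conv_module(src))
    src = src + self.dropout(self.feed_forward(src))
    src = self.norm_final(self.balancer(src))
    src = self.decorrelate(src)   # identity here; adds a decorrelation term to the gradient
    return src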
@@ -1032,7 +1032,6 @@ class Conv2dSubsampling(nn.Module):
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
-       self.decorrelate = Decorrelate(out_channels)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
@@ -1057,7 +1056,6 @@ class Conv2dSubsampling(nn.Module):
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
        x = self.out_norm(x)
-       x = self.decorrelate(x)
        x = self.out_balancer(x)
        return x