mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add a second whitening to the NonlinAttentionModule, after the aggregation.
This commit is contained in:
parent
35f0ea0015
commit
45069175d9
@ -1437,10 +1437,14 @@ class NonlinAttentionModule(nn.Module):
|
||||
bias=True,
|
||||
initial_scale=0.05)
|
||||
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
self.whiten1 = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.01, 0.1),
|
||||
grad_scale=0.01)
|
||||
self.whiten2 = Whiten(num_groups=1,
|
||||
whitening_limit=_whitening_schedule(7.5),
|
||||
prob=(0.01, 0.1),
|
||||
grad_scale=0.01)
|
||||
|
||||
|
||||
|
||||
@ -1465,7 +1469,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
# very small probability to save time).
|
||||
s = penalize_abs_values_gt(s, limit=20.0, penalty=1.0e-04)
|
||||
|
||||
v = self.whiten(v)
|
||||
v = self.whiten1(v)
|
||||
# GLU mechanism
|
||||
x = s.sigmoid() * v
|
||||
x = self.balancer(x)
|
||||
@ -1481,6 +1485,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
|
||||
|
||||
x = self.activation(x) # diagnostics only, it's the identity.
|
||||
x = self.whiten2(x)
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user