Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Remove whitening in SelfAttention module.
commit 1a2632d0a2
parent 19683aa516
@@ -1211,15 +1211,6 @@ class SelfAttention(nn.Module):
                                  num_heads * value_head_dim,
                                  bias=True)
 
-        # attempt to make the output of `in_proj` uncorrelated within each head
-        # and all heads having roughly the same magnitude.  the hope is to
-        # improve learning dynamics; this loses no power as there is no constraint
-        # on the condition number of out_proj.
-        self.whiten_values = Whiten(num_groups=num_heads,
-                                    whitening_limit=2.0,
-                                    prob=(0.025, 0.25),
-                                    grad_scale=0.025)
-
         self.out_proj = ScaledLinear(num_heads * value_head_dim,
                                      embed_dim, bias=True,
                                      initial_scale=0.05)
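For context, the comment block deleted above explains what the Whiten module was for: keeping the output of in_proj decorrelated within each head and roughly balanced in magnitude across heads. As a rough illustration of the quantity such a constraint targets, the sketch below computes a simple per-head "whiteness" measure. This is a simplified stand-in written for this note, not the actual Whiten/metric code in icefall's scaling.py, which applies its constraint probabilistically and only through the backward pass.

import torch


def whitening_metric(x: torch.Tensor, num_groups: int) -> torch.Tensor:
    # Illustrative helper (hypothetical, not from icefall): returns 1.0 when
    # the centered covariance of every group (head) has equal eigenvalues,
    # i.e. features are uncorrelated within each head and all heads have
    # roughly the same magnitude; the value grows as the eigenvalues spread.
    num_channels = x.shape[-1]
    assert num_channels % num_groups == 0
    # reshape to (num_groups, num_frames, channels_per_group)
    x = x.reshape(-1, num_groups, num_channels // num_groups).transpose(0, 1)
    x = x - x.mean(dim=1, keepdim=True)        # center each group's features
    cov = torch.matmul(x.transpose(1, 2), x)   # per-group covariance (unnormalized;
                                               # the ratio below is scale-invariant)
    mean_eig = cov.diagonal(dim1=1, dim2=2).mean()    # mean eigenvalue
    mean_sq_eig = (cov ** 2).sum() / num_channels     # mean squared eigenvalue
    return mean_sq_eig / (mean_eig ** 2 + 1.0e-20)

The removed forward-pass call (see the next hunk) was a no-op on the activations; roughly speaking, Whiten only nudges gradients when a metric of this kind exceeds whitening_limit=2.0, at a strength set by grad_scale=0.025, and only on the fraction of batches given by prob=(0.025, 0.25).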
@@ -1252,7 +1243,6 @@ class SelfAttention(nn.Module):
         assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
 
         x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-        x = self.whiten_values(x)  # does nothing in the forward pass.
         x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, batch_size, seq_len, value_head_dim)
         value_head_dim = x.shape[-1]
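After this change, the value path in SelfAttention.forward is just the projection, a reshape into heads, and the combination with attn_weights. The shape walk-through below uses made-up sizes; the steps past the context lines shown above (the matmul over key positions and the merge of heads before out_proj) are assumptions about standard attention plumbing, not part of this patch.

import torch

# hypothetical sizes, purely to trace tensor shapes
seq_len, batch_size, num_heads, value_head_dim = 10, 2, 4, 12

x = torch.randn(seq_len, batch_size, num_heads * value_head_dim)  # as if from self.in_proj
attn_weights = torch.softmax(
    torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)

x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
# x: (num_heads, batch_size, seq_len, value_head_dim), as in the remaining code

x = torch.matmul(attn_weights, x)  # weighted sum over key positions
# x: (num_heads, batch_size, seq_len, value_head_dim)

x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, num_heads * value_head_dim)
# ready for self.out_proj, which maps back to (seq_len, batch_size, embed_dim)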