Mirror of https://github.com/k2-fsa/icefall.git
Remove whitening in SelfAttention module.

commit 1a2632d0a2
parent 19683aa516
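Context for the change (not part of the original commit message): according to its own comments, the deleted block below was an attempt to make the output of `in_proj` uncorrelated within each head, with all heads at roughly the same magnitude, in the hope of improving learning dynamics. The following is a minimal illustrative sketch of that kind of regularizer, not icefall's actual `Whiten` class from scaling.py; the names `_WhitenPenalty` and `WhitenSketch`, the single-float `prob` (the real call passes a tuple `(0.025, 0.25)`), and the off-diagonal-covariance penalty are assumptions made only for this sketch. It shows the general pattern implied by the removed comment "does nothing in the forward pass": identity in the forward pass, with a small decorrelation gradient occasionally injected in the backward pass.

import torch
import torch.nn as nn


class _WhitenPenalty(torch.autograd.Function):
    # Illustrative only: identity in the forward pass; in the backward pass,
    # adds the gradient of a small decorrelation penalty computed on the
    # saved activations.

    @staticmethod
    def forward(ctx, x, num_groups, grad_scale):
        ctx.num_groups = num_groups
        ctx.grad_scale = grad_scale
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            xd = x.detach().requires_grad_(True)
            head_dim = xd.shape[-1] // ctx.num_groups
            y = xd.reshape(-1, ctx.num_groups, head_dim)
            y = y - y.mean(dim=0, keepdim=True)
            # per-group covariance of the activations
            cov = torch.einsum("ngi,ngj->gij", y, y) / y.shape[0]
            # penalize off-diagonal (cross-channel) covariance mass
            off_diag = cov - torch.diag_embed(cov.diagonal(dim1=-2, dim2=-1))
            penalty = ctx.grad_scale * (off_diag ** 2).mean()
            (penalty_grad,) = torch.autograd.grad(penalty, xd)
        return grad_out + penalty_grad, None, None


class WhitenSketch(nn.Module):
    # Illustrative stand-in for a whitening regularizer; not icefall's Whiten.

    def __init__(self, num_groups, grad_scale=0.025, prob=0.25):
        super().__init__()
        self.num_groups = num_groups
        self.grad_scale = grad_scale
        self.prob = prob  # apply the penalty only on a fraction of batches

    def forward(self, x):
        if not self.training or float(torch.rand(())) > self.prob:
            return x
        return _WhitenPenalty.apply(x, self.num_groups, self.grad_scale)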
@@ -1211,15 +1211,6 @@ class SelfAttention(nn.Module):
                                  num_heads * value_head_dim,
                                  bias=True)
 
-        # attempt to make the output of `in_proj` uncorrelated within each head
-        # and all heads having roughly the same magnitude. the hope is to
-        # improve learning dynamics; this loses no power as there is no constraint
-        # on the condition number of out_proj.
-        self.whiten_values = Whiten(num_groups=num_heads,
-                                    whitening_limit=2.0,
-                                    prob=(0.025, 0.25),
-                                    grad_scale=0.025)
-
         self.out_proj = ScaledLinear(num_heads * value_head_dim,
                                      embed_dim, bias=True,
                                      initial_scale=0.05)
@@ -1252,7 +1243,6 @@ class SelfAttention(nn.Module):
         assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
 
         x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-        x = self.whiten_values(x)  # does nothing in the forward pass.
         x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
         # now x: (num_heads, batch_size, seq_len, value_head_dim)
         value_head_dim = x.shape[-1]
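After this commit, the value path of SelfAttention is simply in_proj, a per-head reshape, application of the attention weights, and out_proj. Below is a minimal self-contained sketch of that path, assuming the surrounding code matches the context lines above; the attention matmul and the final reshape back are not visible in the diff, the class name `SelfAttentionValuePath` is invented for this sketch, and a plain nn.Linear stands in for icefall's ScaledLinear(..., initial_scale=0.05).

import torch
import torch.nn as nn


class SelfAttentionValuePath(nn.Module):
    # Illustrative sketch only; not the full icefall SelfAttention module.

    def __init__(self, embed_dim, num_heads, value_head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.in_proj = nn.Linear(embed_dim,
                                 num_heads * value_head_dim,
                                 bias=True)
        # icefall uses ScaledLinear(..., initial_scale=0.05) here.
        self.out_proj = nn.Linear(num_heads * value_head_dim,
                                  embed_dim, bias=True)

    def forward(self, x, attn_weights):
        seq_len, batch_size, _ = x.shape
        num_heads = self.num_heads
        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)

        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
        # the whiten_values() call that used to follow in_proj is gone
        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
        # now x: (num_heads, batch_size, seq_len, value_head_dim)

        x = torch.matmul(attn_weights, x)  # weighted sum of values per head
        # (num_heads, batch_size, seq_len, value_head_dim)

        x = x.permute(2, 1, 0, 3).contiguous().view(seq_len, batch_size, -1)
        return self.out_proj(x)  # (seq_len, batch_size, embed_dim)


# example usage with toy sizes
m = SelfAttentionValuePath(embed_dim=384, num_heads=8, value_head_dim=12)
x = torch.randn(100, 2, 384)                        # (seq_len, batch, embed_dim)
w = torch.softmax(torch.randn(8, 2, 100, 100), -1)  # (heads, batch, seq, seq)
y = m(x, w)                                         # (100, 2, 384)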