mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Reduce whitening limit to 10 and move it to the beginning.
This commit is contained in:
parent
584f5bf88c
commit
56efdcda49
@ -1423,7 +1423,7 @@ class NonlinAttentionModule(nn.Module):
|
||||
min_prob=0.1,
|
||||
)
|
||||
self.whiten = Whiten(num_groups=1,
|
||||
whitening_limit=20.0,
|
||||
whitening_limit=10.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.01)
|
||||
|
||||
@ -1444,7 +1444,9 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
Returns:
|
||||
a Tensor with the same shape as x
|
||||
"""
|
||||
v, s = self.in_proj(x).chunk(2, dim=-1)
|
||||
x = self.in_proj(x)
|
||||
x = self.whiten(x)
|
||||
v, s = x.chunk(2, dim=-1)
|
||||
|
||||
if self.training and random.random() < 0.02:
|
||||
# prevent the inputs to the sigmoid from getting very large (this is
|
||||
@ -1455,7 +1457,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
# GLU mechanism
|
||||
x = s.sigmoid() * v
|
||||
x = self.balancer(x)
|
||||
x = self.whiten(x)
|
||||
|
||||
(seq_len, batch_size, embed_dim) = x.shape
|
||||
num_heads = attn_weights.shape[0]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user