mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
From 460->461: revert the change that balanced the output of the attention_squeeze module.
This commit is contained in:
parent
fe51eea397
commit
d95571eacf
@ -1343,17 +1343,10 @@ class AttentionSqueeze(nn.Module):
|
|||||||
self.out_proj = ScaledLinear(embed_dim, embed_dim,
|
self.out_proj = ScaledLinear(embed_dim, embed_dim,
|
||||||
bias=False, initial_scale=0.05)
|
bias=False, initial_scale=0.05)
|
||||||
|
|
||||||
|
|
||||||
self.out_whiten = Whiten(num_groups=1,
|
self.out_whiten = Whiten(num_groups=1,
|
||||||
whitening_limit=10.0,
|
whitening_limit=10.0,
|
||||||
prob=(0.01, 0.1),
|
prob=(0.01, 0.1),
|
||||||
grad_scale=0.01)
|
grad_scale=0.01)
|
||||||
self.out_balancer = ActivationBalancer(
|
|
||||||
embed_dim, channel_dim=-1,
|
|
||||||
min_positive=0.45, max_positive=0.55,
|
|
||||||
min_abs=0.005, max_abs=2.0,
|
|
||||||
min_prob=0.05,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@ -1388,7 +1381,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
x = self.activation(x) # Identity only. For diagnostics.
|
x = self.activation(x) # Identity only. For diagnostics.
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
x = self.out_whiten(x)
|
x = self.out_whiten(x)
|
||||||
x = self.out_balancer(x)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user