mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add output balancer to NonlinAttentionModule.
This commit is contained in:
parent
71f118e725
commit
fe1793e288
@ -1415,12 +1415,12 @@ class NonlinAttentionModule(nn.Module):
|
||||
|
||||
self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
|
||||
|
||||
# balancer goes after the glu mechanism.
|
||||
# balancer that goes after the glu mechanism.
|
||||
self.balancer = ActivationBalancer(
|
||||
channels, channel_dim=-1,
|
||||
min_positive=0.2, max_positive=0.8,
|
||||
min_abs=0.2, max_abs=10.0,
|
||||
min_prob=0.1,
|
||||
min_prob=0.05,
|
||||
)
|
||||
# give it a high limit, because it is quite high-dimensional and is
|
||||
# a projection of a lower-dimensional embedding.
|
||||
@ -1435,6 +1435,18 @@ class NonlinAttentionModule(nn.Module):
|
||||
bias=True,
|
||||
initial_scale=0.05)
|
||||
|
||||
|
||||
# put quite strict limits on the min_positive and max_positive at the output,
|
||||
# because we noticed that poorly-trained instances of NonlinAttentionModule seem
|
||||
# to have a larger mean-offset at the output for some reason.
|
||||
self.out_balancer = ActivationBalancer(
|
||||
channels, channel_dim=-1,
|
||||
min_positive=0.4, max_positive=0.6,
|
||||
min_abs=0.005, max_abs=1.0,
|
||||
min_prob=0.05,
|
||||
)
|
||||
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
attn_weights: Tensor,
|
||||
@ -1472,7 +1484,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
|
||||
x = self.activation(x) # diagnostics only, it's the identity.
|
||||
x = self.out_proj(x)
|
||||
|
||||
x = self.out_balancer(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user