Add output balancer to NonlinAttentionModule.

Daniel Povey 2022-11-22 14:29:07 +08:00
parent 71f118e725
commit fe1793e288


@@ -1415,12 +1415,12 @@ class NonlinAttentionModule(nn.Module):
         self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
 
-        # balancer goes after the glu mechanism.
+        # balancer that goes after the glu mechanism.
         self.balancer = ActivationBalancer(
             channels, channel_dim=-1,
             min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=10.0,
-            min_prob=0.1,
+            min_prob=0.05,
         )
         # give it a high limit, because it is quite high-dimensional and is
         # a projection of a lower-dimensional embedding.
@@ -1435,6 +1435,18 @@
                                      bias=True,
                                      initial_scale=0.05)
+
+        # put quite strict limits on the min_positive and max_positive at the output,
+        # because we noticed that poorly-trained instances of NonlinAttentionModule seem
+        # to have a larger mean-offset at the output for some reason.
+        self.out_balancer = ActivationBalancer(
+            channels, channel_dim=-1,
+            min_positive=0.4, max_positive=0.6,
+            min_abs=0.005, max_abs=1.0,
+            min_prob=0.05,
+        )
 
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor,
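The new out_balancer pins the output's positive fraction near 0.5 (limits [0.4, 0.6]), i.e. a per-channel median near zero, to counteract the mean-offset described in the comment; because ActivationBalancer acts only on gradients, the forward output itself is unchanged. The sketch below is a toy illustration of that gradient-level balancing idea, not icefall's actual implementation; ToyBalancerFn and its scale argument are invented for illustration:

    import torch

    class ToyBalancerFn(torch.autograd.Function):
        # Identity in the forward pass. In the backward pass, adds a small
        # per-channel term to the gradient that pushes channels whose
        # positive fraction lies outside [min_positive, max_positive]
        # back toward a zero mean. Assumes channels are the last dim.

        @staticmethod
        def forward(ctx, x, min_positive, max_positive, scale):
            ctx.save_for_backward(x)
            ctx.limits = (min_positive, max_positive, scale)
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            min_positive, max_positive, scale = ctx.limits
            frac = (x > 0).float().mean(dim=tuple(range(x.dim() - 1)))
            too_pos = (frac > max_positive).float()  # push these channels down
            too_neg = (frac < min_positive).float()  # push these channels up
            nudge = scale * grad_out.abs().mean() * (too_pos - too_neg)
            return grad_out + nudge, None, None, None

Used as x = ToyBalancerFn.apply(x, 0.4, 0.6, 0.01) inside a forward pass, this leaves activations untouched while steering training toward the target range.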
@@ -1472,7 +1484,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.activation(x)  # diagnostics only, it's the identity.
         x = self.out_proj(x)
+        x = self.out_balancer(x)
         return x
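Why the output limits are much tighter than the internal balancer's [0.2, 0.8]: the mean-offset the comment describes shows up directly in the positive-fraction statistic. A quick hypothetical check:

    import torch

    x = torch.randn(1000, 512)             # well-behaved, zero-mean output
    shifted = x + 0.3                      # simulate a mean-offset at the output
    print((x > 0).float().mean())          # ~0.50
    print((shifted > 0).float().mean())    # ~0.62: flagged by [0.4, 0.6],
                                           # but not by the looser [0.2, 0.8]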