Change balancer to whitener for ff module; tighter min/max-pos limit on NonlinAttentionModule; whitener->balancer for AttentionSqueeze.

Daniel Povey 2022-11-22 15:42:41 +08:00
parent 26916f41e7
commit 7acdaea085


@@ -1306,7 +1306,8 @@ class AttentionSqueeze(nn.Module):
             max_factor=0.02,
             min_prob=0.1,
         )
-        self.activation = DoubleSwish()  # in bottleneck
+        self.bottleneck_activation = DoubleSwish()  # in bottleneck
+        self.activation = Identity()  # for diagnostics
         # the next two balancers are only to stop parameter-magnitude 'drift': we have
         # too many degrees of freedom for the scales of the various activations.
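The only change inside AttentionSqueeze's constructor here is a rename: the DoubleSwish in the bottleneck becomes self.bottleneck_activation, while self.activation is kept as an Identity, presumably so that activation-statistics diagnostics can still attach at that point without changing the computation. For reference, DoubleSwish (from icefall's scaling.py) computes approximately x * sigmoid(x - 1); a minimal stand-in, ignoring the memory-saving custom autograd of the real module and using a made-up class name, would be:

import torch
from torch import Tensor, nn

class SimpleDoubleSwish(nn.Module):
    # Toy stand-in: f(x) = x * sigmoid(x - 1); same math as DoubleSwish,
    # without its memory-efficient custom backward.
    def forward(self, x: Tensor) -> Tensor:
        return x * torch.sigmoid(x - 1.0)
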
@@ -1331,11 +1332,12 @@
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                      bias=False, initial_scale=0.05)
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
-                                 prob=(0.025, 0.25),
-                                 grad_scale=0.01)
+        self.out_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=0.005, max_abs=2.0,
+            min_prob=0.05,
+        )

     def forward(self,
                 x: Tensor,
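Per the commit message, AttentionSqueeze's output constraint switches from a Whiten module to an ActivationBalancer. The balancer is an identity in the forward pass; during training it nudges gradients so that, per channel, the proportion of positive values stays roughly inside [min_positive, max_positive] and the mean absolute value inside [min_abs, max_abs] (min_prob and max_factor govern how often and how strongly it acts). The sketch below only illustrates that idea; the class and the way the statistics are turned into a gradient correction are simplified stand-ins, not icefall's implementation:

import torch
from torch import Tensor, nn

class _ToyBalancerGrad(torch.autograd.Function):
    # Identity in forward; in backward, add a term proportional to |grad| that
    # pushes out-of-range channels back toward their target statistics.
    @staticmethod
    def forward(ctx, x: Tensor, factor: Tensor, grad_scale: float) -> Tensor:
        ctx.save_for_backward(factor)
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        (factor,) = ctx.saved_tensors
        # factor > 0 means "push these values down", factor < 0 means "push them up".
        return grad_out + ctx.grad_scale * grad_out.abs() * factor, None, None

class ToyBalancer(nn.Module):
    def __init__(self, channel_dim: int = -1,
                 min_positive: float = 0.45, max_positive: float = 0.55,
                 min_abs: float = 0.005, max_abs: float = 2.0,
                 grad_scale: float = 0.01):
        super().__init__()
        self.channel_dim = channel_dim
        self.min_positive, self.max_positive = min_positive, max_positive
        self.min_abs, self.max_abs = min_abs, max_abs
        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        if not (self.training and x.requires_grad):
            return x
        dims = [d for d in range(x.ndim) if d != self.channel_dim % x.ndim]
        with torch.no_grad():
            # per-channel statistics; they only shape the gradient, not the value.
            prop_pos = (x > 0).float().mean(dim=dims, keepdim=True)
            mean_abs = x.abs().mean(dim=dims, keepdim=True)
            factor = ((prop_pos > self.max_positive).float()
                      - (prop_pos < self.min_positive).float()
                      + x.sign() * ((mean_abs > self.max_abs).float()
                                    - (mean_abs < self.min_abs).float()))
        return _ToyBalancerGrad.apply(x, factor, self.grad_scale)

Wrapping a tensor in ToyBalancer() leaves its value unchanged but shapes its gradient during training, which is why such a module can sit in a forward() without affecting inference.
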
@@ -1358,7 +1360,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
         bottleneck = self.bottleneck_balancer(bottleneck)
-        bottleneck = self.activation(bottleneck)
+        bottleneck = self.bottleneck_activation(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3)  # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
@@ -1367,8 +1369,9 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.activation_balancer(x)
         scales = self.scale_balancer(scales)
         x = x * scales
+        x = self.activation(x)  # Identity only. For diagnostics.
         x = self.out_proj(x)
-        x = self.out_whiten(x)
+        x = self.out_balancer(x)
         return x
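The bottleneck path in this forward() can be traced purely in terms of shapes. In the snippet below the sizes, the softmaxed random attn_weights, and the plain nn.Linear standing in for from_bottleneck_proj are all placeholders; only the shape bookkeeping mirrors the code above:

import torch

num_heads, batch_size, seq_len, head_dim = 4, 2, 10, 16
bottleneck_dim = num_heads * head_dim   # assumption: heads concatenated after the permute
embed_dim = 256

attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
bottleneck = torch.randn(num_heads, batch_size, seq_len, head_dim)

bottleneck = torch.matmul(attn_weights, bottleneck)        # (num_heads, batch, seq, head_dim)
bottleneck = bottleneck.permute(2, 1, 0, 3)                # (seq, batch, num_heads, head_dim)
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)

from_bottleneck_proj = torch.nn.Linear(bottleneck_dim, embed_dim)  # placeholder projection
scales = from_bottleneck_proj(bottleneck)                  # (seq, batch, embed_dim)

x = torch.randn(seq_len, batch_size, embed_dim)
x = x * scales                                             # per-channel gating of the main path
print(x.shape)                                             # torch.Size([10, 2, 256])
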
@@ -1388,11 +1391,10 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
-        self.out_balancer = ActivationBalancer(embed_dim,
-                                               min_positive=0.4, max_positive=0.6,
-                                               min_abs=0.01, max_abs=5.0,
-                                               channel_dim=-1, min_prob=0.1)
+        self.out_whitener = Whiten(num_groups=1,
+                                   whitening_limit=10.0,
+                                   prob=(0.025, 0.25),
+                                   grad_scale=0.01)

     def forward(self,
                 x: Tensor):
@@ -1401,7 +1403,7 @@ class FeedforwardModule(nn.Module):
         x = self.activation(x)
         x = self.dropout(x)
         x = self.out_proj(x)
-        x = self.out_balancer(x)
+        x = self.out_whitener(x)
         return x
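In FeedforwardModule the swap goes the other way: the ActivationBalancer on the output is replaced by a Whiten module. Whiten also leaves the forward value untouched; during training it occasionally (per prob) measures how far the output feature covariance is from white and, when that measure exceeds whitening_limit, contributes a small gradient (bounded via grad_scale) that spreads variance more evenly across directions. The snippet below sketches one plausible form of such a statistic; it is not the exact formula used in scaling.py:

import torch
from torch import Tensor

def toy_whitening_metric(x: Tensor, num_groups: int = 1) -> Tensor:
    # Roughly 1.0 when the feature covariance is isotropic ("white"), approaching
    # the number of channels per group when all the energy sits in one direction.
    x = x.reshape(-1, x.shape[-1])                              # (frames, channels)
    num_channels = x.shape[-1]
    assert num_channels % num_groups == 0
    x = x.reshape(-1, num_groups, num_channels // num_groups).transpose(0, 1)
    cov = torch.matmul(x.transpose(1, 2), x) / x.shape[1]       # per-group covariance
    trace = cov.diagonal(dim1=1, dim2=2).sum(dim=-1)            # sum of eigenvalues
    frob_sq = (cov * cov).sum(dim=(1, 2))                       # sum of squared eigenvalues
    return (frob_sq * cov.shape[-1] / trace.clamp(min=1e-20) ** 2).mean()

x_white = torch.randn(2000, 256)                            # roughly isotropic features
x_collapsed = torch.randn(2000, 1) * torch.randn(1, 256)    # all variance in one direction
print(toy_whitening_metric(x_white))                        # a little above 1
print(toy_whitening_metric(x_collapsed))                    # close to 256

So after this commit the feedforward output is no longer constrained per channel in sign and magnitude; instead its covariance is kept from collapsing onto a few dominant directions.
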
@@ -1447,7 +1449,7 @@ class NonlinAttentionModule(nn.Module):
         # to have a larger mean-offset at the output for some reason.
         self.out_balancer = ActivationBalancer(
             channels, channel_dim=-1,
-            min_positive=0.4, max_positive=0.6,
+            min_positive=0.45, max_positive=0.55,
             min_abs=0.005, max_abs=1.0,
             min_prob=0.05,
         )
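min_positive and max_positive bound the fraction of frames on which each output channel is positive, so tightening the window from [0.4, 0.6] to [0.45, 0.55] keeps each channel's output closer to symmetric about zero, counteracting the drift toward a larger mean offset mentioned in the comment above. A quick illustration of the statistic involved, with made-up sizes and an arbitrary 0.2 offset on one channel:

import torch

x = torch.randn(1000, 8, 4)     # (seq_len, batch, channels), made-up sizes
x[..., 0] += 0.2                # give channel 0 a positive mean offset

prop_positive = (x > 0).float().mean(dim=(0, 1))   # per-channel fraction of positive values
print(prop_positive)            # channel 0 around 0.58, the others around 0.50

old_ok = (prop_positive >= 0.4) & (prop_positive <= 0.6)
new_ok = (prop_positive >= 0.45) & (prop_positive <= 0.55)
print(old_ok)                   # channel 0 still passes the old [0.4, 0.6] window
print(new_ok)                   # ...but violates the tighter [0.45, 0.55] one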