Changes and bug-fixes re balancers; restore activation in AttentionSqueeze, remove it in NonlinAttention.

Daniel Povey 2022-11-21 14:29:36 +08:00
parent 9fe6add587
commit 836c72dd36

@@ -32,8 +32,7 @@ from scaling import (
ScaledConv1d,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
Whiten,
Identity,
_diag,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
softmax,
ScheduledFloat,
@@ -212,7 +211,7 @@ class Zipformer(EncoderInterface):
for i in range(len(z)):
if i <= 1 or z[i-1] <= z[i]:
skip_layers.append(None)
skip_modules.append(nn.Identity())
skip_modules.append(Identity())
else:
# TEMP
for j in range(i-2, -1, -1):
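The nn.Identity() placeholders for skipped layers are swapped for the repo's own Identity, which the new import comment describes as friendlier to the backward hooks used for diagnostics. A minimal sketch of such a module, assuming it is simply a pass-through (the scaling.py version may carry extra diagnostic conveniences):

    import torch
    import torch.nn as nn

    class Identity(nn.Module):
        # Pass-through module. Unlike nn.Identity, it is defined inside the repo,
        # which (per the import comment above) plays more nicely with backward
        # hooks registered for diagnostics.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

    skip_modules = [Identity()]  # as appended for the non-skipping case above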
@@ -1300,11 +1299,12 @@ class AttentionSqueeze(nn.Module):
self.bottleneck_balancer = ActivationBalancer(
bottleneck_dim, channel_dim=-1,
min_positive=0.05, max_positive=0.95,
min_abs=0.2,
max_abs=1.0,
min_abs=0.05,
max_abs=2.0,
max_factor=0.02,
min_prob=0.1,
)
self.activation = DoubleSwish() # in bottleneck
# the next two balancers are only to stop parameter-magnitude 'drift': we have
# too many degrees of freedom for the scales of the various activations.
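The restored self.activation in the bottleneck is a DoubleSwish, and the bottleneck_balancer's min_abs/max_abs window on its input is widened from 0.2..1.0 to 0.05..2.0. A minimal functional sketch, assuming DoubleSwish computes x * sigmoid(x - 1) as in scaling.py (the repo version also includes a memory-optimized autograd implementation, omitted here):

    import torch

    def double_swish(x: torch.Tensor) -> torch.Tensor:
        # DoubleSwish(x) = x * sigmoid(x - 1), a Swish-like nonlinearity.
        # In AttentionSqueeze it is applied to the bottleneck after the
        # bottleneck_balancer, whose constraints on that input are now looser
        # (min_abs=0.05, max_abs=2.0 instead of 0.2 and 1.0).
        return x * torch.sigmoid(x - 1.0)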
@@ -1313,15 +1313,15 @@
# put no limitation on the signs (so: min_positive=0, max_positive=1).
self.scale_balancer = ActivationBalancer(
embed_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_positive=0.2, max_positive=0.8,
min_abs=0.2, max_abs=1.0,
min_prob=0.025,
min_prob=0.05,
)
self.activation_balancer = ActivationBalancer(
embed_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_positive=0.2, max_positive=0.8,
min_abs=0.2, max_abs=1.0,
min_prob=0.025,
min_prob=0.05,
)
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
@@ -1356,14 +1356,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
# -> (num_heads, batch_size, seq_len, head_dim)
bottleneck = torch.matmul(attn_weights, bottleneck)
bottleneck = self.bottleneck_balancer(bottleneck)
bottleneck = self.activation(bottleneck)
bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
scales = self.from_bottleneck_proj(bottleneck)
scales = self.scale_balancer(scales)
x = self.activation_balancer(x)
x = self.in_proj(x)
x = self.activation_balancer(x)
scales = self.scale_balancer(scales)
x = x * scales
x = self.out_proj(x)
x = self.out_whiten(x)
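With the activation restored and the calls reordered, the squeeze path now runs matmul → bottleneck_balancer → DoubleSwish → reshape → from_bottleneck_proj → scale_balancer, while the main path applies activation_balancer after in_proj rather than before it. A runnable shape-level sketch of that flow; the balancers are omitted (they do not change shapes), a closed-form DoubleSwish stands in for self.activation, and bottleneck_dim = num_heads * head_dim is an assumption:

    import torch
    import torch.nn as nn

    seq_len, batch_size, embed_dim = 10, 2, 16
    num_heads, head_dim = 4, 3
    bottleneck_dim = num_heads * head_dim  # assumption for this sketch

    in_proj = nn.Linear(embed_dim, embed_dim)
    from_bottleneck_proj = nn.Linear(bottleneck_dim, embed_dim)
    out_proj = nn.Linear(embed_dim, embed_dim)

    x = torch.randn(seq_len, batch_size, embed_dim)
    attn_weights = torch.softmax(
        torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
    # stand-in for the (not shown) projection of x down to the bottleneck heads
    bottleneck = torch.randn(num_heads, batch_size, seq_len, head_dim)

    bottleneck = torch.matmul(attn_weights, bottleneck)        # (num_heads, batch, seq_len, head_dim)
    bottleneck = bottleneck * torch.sigmoid(bottleneck - 1.0)  # restored activation (DoubleSwish-like)
    bottleneck = bottleneck.permute(2, 1, 0, 3).reshape(seq_len, batch_size, bottleneck_dim)
    scales = from_bottleneck_proj(bottleneck)                  # scale_balancer would follow here
    y = in_proj(x)                                             # activation_balancer now follows in_proj
    y = out_proj(y * scales)                                   # out_whiten would follow here
    print(y.shape)  # torch.Size([10, 2, 16])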
@@ -1399,7 +1398,8 @@ class FeedforwardModule(nn.Module):
class NonlinAttentionModule(nn.Module):
"""This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
from the attention module) in place of actual convolution.
from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
one after the attention mechanism.
Args:
channels (int): The number of channels of conv layers.
@@ -1410,24 +1411,9 @@ class NonlinAttentionModule(nn.Module):
) -> None:
super().__init__()
# to_scale and to_value are analogous to pointwise_conv1 in ConvolutionModule
# we make them separate because we need an extra degree of freedom for the
# scale, as the attention weights are constrained to sum to one so cannot
# provide the degree of freedom for the scale of the features before
# self.activation().
self.to_scale = nn.Linear(channels, channels, bias=True)
self.to_value = nn.Linear(channels, channels, bias=True)
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
self.deriv_balancer = ActivationBalancer(
channels, channel_dim=-1,
min_positive=0.05, max_positive=1.0,
max_abs=20.0,
)
self.activation = DoubleSwish()
self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
self.activation = Identity() # for diagnostics.
self.out_proj = ScaledLinear(channels, channels,
bias=True,
initial_scale=0.05)
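The separate to_scale and to_value linears (along with the deriv_balancer and DoubleSwish that used to follow the attention step) are replaced by a single in_proj of width 2 * channels whose output is split with chunk. Fusing two linears that share an input this way is exact; a small sketch with toy sizes (the names to_value / to_scale are kept only for the comparison):

    import torch
    import torch.nn as nn

    channels = 8
    x = torch.randn(5, channels)

    # old style: two separate projections of the same input
    to_value = nn.Linear(channels, channels, bias=True)
    to_scale = nn.Linear(channels, channels, bias=True)
    v_old, s_old = to_value(x), to_scale(x)

    # new style: one fused projection, then split into (value, scale)
    in_proj = nn.Linear(channels, 2 * channels, bias=True)
    with torch.no_grad():
        in_proj.weight.copy_(torch.cat([to_value.weight, to_scale.weight], dim=0))
        in_proj.bias.copy_(torch.cat([to_value.bias, to_scale.bias], dim=0))
    v_new, s_new = in_proj(x).chunk(2, dim=-1)

    assert torch.allclose(v_old, v_new, atol=1e-6)
    assert torch.allclose(s_old, s_new, atol=1e-6)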
@@ -1443,8 +1429,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
Returns:
a Tensor with the same shape as x
"""
s = self.to_scale(x)
v = self.to_value(x)
v, s = self.in_proj(x).chunk(2, dim=-1)
if self.training and random.random() < 0.02:
# prevent the inputs to the sigmoid from getting very large (this is
# hopefully quite a rare phenomenon, so we are giving this path a
@@ -1463,8 +1449,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
x = torch.matmul(attn_weights, x)
# now x: (num_heads, batch_size, seq_len, head_dim)
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
x = self.deriv_balancer(x)
x = self.activation(x)
x = self.activation(x) # diagnostics only, it's the identity.
x = self.out_proj(x)
return x
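In the new forward, the deriv_balancer and DoubleSwish after the attention step are gone (self.activation is now an Identity kept only so diagnostics still see a named module), so the remaining nonlinearity is the gating of v by s. The gate itself and the split into heads fall outside the shown hunks, so the sketch below assumes v * sigmoid(s) (consistent with the comment about the sigmoid inputs) and the usual head reshape; a shape-level, runnable version:

    import torch

    seq_len, batch_size, num_heads, head_dim = 10, 2, 4, 3
    channels = num_heads * head_dim

    x = torch.randn(seq_len, batch_size, channels)
    attn_weights = torch.softmax(
        torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)

    v, s = x, x.clone()                        # stand-ins for in_proj(x).chunk(2, dim=-1)
    x = v * torch.sigmoid(s)                   # assumed gating; only the sigmoid is hinted at above
    x = x.reshape(seq_len, batch_size, num_heads, head_dim).permute(2, 1, 0, 3)
    x = torch.matmul(attn_weights, x)          # (num_heads, batch, seq_len, head_dim)
    x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
    # out_proj (a ScaledLinear) would follow here; no further nonlinearity.
    print(x.shape)  # torch.Size([10, 2, 12])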