Remove activation in AttentionSqueeze; add balancers; fix bugs RE balancers.

This commit is contained in:
Daniel Povey 2022-11-19 22:05:10 +08:00
parent e9806950f5
commit 4e21db07f6
2 changed files with 36 additions and 45 deletions

View File

@ -122,40 +122,6 @@ def _compute_sign_factor(x: Tensor,
class ActivationScaleBalancerFunction(torch.autograd.Function):
"""
This object is used in class ActivationBalancer when the user specified
min_positive=0, max_positive=1, so there are no constraints on the signs
of the activations and only the absolute value has a constraint.
"""
@staticmethod
def forward(
ctx,
x: Tensor,
sign_factor: Tensor,
scale_factor: Tensor,
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
xgt0 = (x > 0)
ctx.save_for_backward(xgt0, sign_factor, scale_factor)
return x
@staticmethod
def backward(
ctx, x_grad: Tensor
) -> Tuple[Tensor, None, None, None]:
xgt0, sign_factor, scale_factor = ctx.saved_tensors
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
sign_factor = sign_factor.unsqueeze(-1)
scale_factor = scale_factor.unsqueeze(-1)
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return x_grad - neg_delta_grad, None, None, None,
def random_cast_to_half(x: Tensor, def random_cast_to_half(x: Tensor,
@ -465,6 +431,7 @@ class ActivationBalancer(torch.nn.Module):
# loop. # loop.
self.batch_count = 0 self.batch_count = 0
# actually self.num_channels is no longer needed except for an assertion.
self.num_channels = num_channels self.num_channels = num_channels
self.channel_dim = channel_dim self.channel_dim = channel_dim
self.min_positive = min_positive self.min_positive = min_positive
@ -489,6 +456,7 @@ class ActivationBalancer(torch.nn.Module):
prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0))) prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
if random.random() < prob: if random.random() < prob:
assert x.shape[self.channel_dim] == self.num_channels
sign_gain_factor = 0.5 sign_gain_factor = 0.5
if self.min_positive != 0.0 or self.max_positive != 1.0: if self.min_positive != 0.0 or self.max_positive != 1.0:
sign_factor = _compute_sign_factor(x, self.channel_dim, sign_factor = _compute_sign_factor(x, self.channel_dim,

View File

@ -1288,25 +1288,46 @@ class AttentionSqueeze(nn.Module):
bottleneck_dim, bottleneck_dim,
bias=False) bias=False)
# the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
# dynamic range, to avoid parameter-size 'drift' where to_bottleneck_proj gets large and from_bottleneck_proj
# gets small or vice versa.
# Caution: this cannot work correctly with an extremeley small batch size, e.g. if # Caution: this cannot work correctly with an extremeley small batch size, e.g. if
# we were training with a single very long audio sequence, or just 2 or 3 sequences # we were training with a single very long audio sequence, or just 2 or 3 sequences
# at a time. We make max_factor small to reduce the harm this could cause # at a time. We make max_factor small to reduce the harm this could cause
# (although when the grads get back past the averaging operation they would # (although when the grads get back past the averaging operation they would
# be quite small and would probably not hurt the rest of the model much.) # be quite small and would probably not hurt the rest of the model much.)
self.balancer = ActivationBalancer( self.bottleneck_balancer = ActivationBalancer(
embed_dim, channel_dim=-1, bottleneck_dim, channel_dim=-1,
min_positive=0.05, max_positive=0.95, min_positive=0.05, max_positive=0.95,
min_abs=0.1, min_abs=0.2,
max_abs=50.0, max_abs=1.0,
max_factor=0.02, max_factor=0.02,
min_prob=0.2, min_prob=0.1,
)
# the next two balancers are only to stop parameter-magnitude 'drift': we have
# too many degrees of freedom for the scales of the various activations.
# Make them run with very low probability, since only a small application of
# these balancers should be enough to stop such "drift"; and, for speed,
# put no limitation on the signs (so: min_positive=0, max_positive=1).
self.scale_balancer = ActivationBalancer(
embed_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.2, max_abs=1.0,
min_prob=0.025,
)
self.activation_balancer = ActivationBalancer(
embed_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.2, max_abs=1.0,
min_prob=0.025,
) )
self.activation = DoubleSwish()
self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim) self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
self.out_proj = ScaledLinear(embed_dim, embed_dim, self.out_proj = ScaledLinear(embed_dim, embed_dim,
bias=False, initial_scale=0.1) bias=False, initial_scale=0.05)
self.out_whiten = Whiten(num_groups=1, self.out_whiten = Whiten(num_groups=1,
whitening_limit=10.0, whitening_limit=10.0,
@ -1334,12 +1355,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
# (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim) # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
# -> (num_heads, batch_size, seq_len, head_dim) # -> (num_heads, batch_size, seq_len, head_dim)
bottleneck = torch.matmul(attn_weights, bottleneck) bottleneck = torch.matmul(attn_weights, bottleneck)
bottleneck = self.balancer(bottleneck) bottleneck = self.bottleneck_balancer(bottleneck)
bottleneck = self.activation(bottleneck)
bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim) bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim) bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
scales = self.from_bottleneck_proj(bottleneck) scales = self.from_bottleneck_proj(bottleneck)
scales = self.scale_balancer(scales)
x = self.activation_balancer(x)
x = self.in_proj(x) x = self.in_proj(x)
x = x * scales x = x * scales
x = self.out_proj(x) x = self.out_proj(x)
@ -1398,7 +1421,7 @@ class NonlinAttentionModule(nn.Module):
# deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule # deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
self.deriv_balancer = ActivationBalancer( self.deriv_balancer = ActivationBalancer(
channels, channel_dim=1, channels, channel_dim=-1,
min_positive=0.05, max_positive=1.0, min_positive=0.05, max_positive=1.0,
max_abs=20.0, max_abs=20.0,
) )
@ -1439,8 +1462,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
# now x: (num_heads, batch_size, seq_len, head_dim) # now x: (num_heads, batch_size, seq_len, head_dim)
x = torch.matmul(attn_weights, x) x = torch.matmul(attn_weights, x)
# now x: (num_heads, batch_size, seq_len, head_dim) # now x: (num_heads, batch_size, seq_len, head_dim)
x = self.deriv_balancer(x)
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
x = self.deriv_balancer(x)
x = self.activation(x) x = self.activation(x)
x = self.out_proj(x) x = self.out_proj(x)