Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Remove activation in AttentionSqueeze; add balancers; fix bugs RE balancers.
Commit 4e21db07f6 (parent e9806950f5)
```diff
@@ -122,40 +122,6 @@ def _compute_sign_factor(x: Tensor,
 
 
 
-class ActivationScaleBalancerFunction(torch.autograd.Function):
-    """
-    This object is used in class ActivationBalancer when the user specified
-    min_positive=0, max_positive=1, so there are no constraints on the signs
-    of the activations and only the absolute value has a constraint.
-    """
-    @staticmethod
-    def forward(
-        ctx,
-        x: Tensor,
-        sign_factor: Tensor,
-        scale_factor: Tensor,
-        channel_dim: int,
-    ) -> Tensor:
-        if channel_dim < 0:
-            channel_dim += x.ndim
-        ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
-        return x
-
-
-    @staticmethod
-    def backward(
-        ctx, x_grad: Tensor
-    ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
-
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
-        neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
-
-
 def random_cast_to_half(x: Tensor,
```
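The removed class uses the gradient-nudging trick shared by the balancer functions in this file: the forward pass is the identity, and the backward pass subtracts a small term proportional to `|grad|`, where `sign_factor` pushes each channel toward the target fraction of positive values and `scale_factor` pushes its magnitude toward the allowed range. The sketch below is illustrative only (toy tensors and made-up per-channel factors, not part of the commit) and just evaluates that backward formula directly:

```python
import torch

# Toy evaluation of the balancer backward-pass formula:
#   factor   = sign_factor + scale_factor * (xgt0 - 0.5)
#   new_grad = grad - |grad| * factor
x = torch.tensor([[-1.0, 2.0], [0.5, -3.0]])      # (frames, channels), channel_dim = -1
grad = torch.tensor([[0.1, -0.2], [0.3, 0.4]])    # incoming gradient
sign_factor = torch.tensor([0.01, -0.01])         # hypothetical per-channel values
scale_factor = torch.tensor([0.02, 0.0])          # hypothetical per-channel values

xgt0 = (x > 0).to(grad.dtype)
factor = sign_factor + scale_factor * (xgt0 - 0.5)
new_grad = grad - grad.abs() * factor             # gradient nudged, activations untouched
print(new_grad)
```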
```diff
@@ -465,6 +431,7 @@ class ActivationBalancer(torch.nn.Module):
             # loop.
             self.batch_count = 0
 
+        # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
```
```diff
@@ -489,6 +456,7 @@ class ActivationBalancer(torch.nn.Module):
         prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
+            assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
             if self.min_positive != 0.0 or self.max_positive != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
```
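The `prob` line means the balancer only runs on a random subset of batches: the probability starts at 0.5 and halves every 4000 batches until it reaches `min_prob`. The newly added assert is now the only remaining use of `self.num_channels`. A quick sketch of the schedule (plain Python; the `min_prob` value here is hypothetical, it is whatever the module was constructed with):

```python
# prob = max(min_prob, 0.5 ** (1 + batch_count / 4000.0))
min_prob = 0.1  # hypothetical constructor argument
for batch_count in [0, 4000, 8000, 16000, 64000]:
    prob = max(min_prob, 0.5 ** (1 + batch_count / 4000.0))
    print(f"batch_count={batch_count:6d}  prob={prob:.4f}")
# batch_count=     0  prob=0.5000
# batch_count=  4000  prob=0.2500
# batch_count=  8000  prob=0.1250
# batch_count= 16000  prob=0.1000   (clamped at min_prob)
# batch_count= 64000  prob=0.1000
```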
```diff
@@ -1288,25 +1288,46 @@ class AttentionSqueeze(nn.Module):
                                             bottleneck_dim,
                                             bias=False)
 
 
         # the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
         # dynamic range, to avoid parameter-size 'drift' where to_bottleneck_proj gets large and from_bottleneck_proj
         # gets small or vice versa.
+        # Caution: this cannot work correctly with an extremeley small batch size, e.g. if
+        # we were training with a single very long audio sequence, or just 2 or 3 sequences
+        # at a time.  We make max_factor small to reduce the harm this could cause
+        # (although when the grads get back past the averaging operation they would
+        # be quite small and would probably not hurt the rest of the model much.)
-        self.balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+        self.bottleneck_balancer = ActivationBalancer(
+            bottleneck_dim, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
-            min_abs=0.1,
-            max_abs=50.0,
+            min_abs=0.2,
+            max_abs=1.0,
             max_factor=0.02,
-            min_prob=0.2,
+            min_prob=0.1,
         )
 
+        # the next two balancers are only to stop parameter-magnitude 'drift': we have
+        # too many degrees of freedom for the scales of the various activations.
+        # Make them run with very low probability, since only a small application of
+        # these balancers should be enough to stop such "drift"; and, for speed,
+        # put no limitation on the signs (so: min_positive=0, max_positive=1).
+        self.scale_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.0, max_positive=1.0,
+            min_abs=0.2, max_abs=1.0,
+            min_prob=0.025,
+        )
+        self.activation_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.0, max_positive=1.0,
+            min_abs=0.2, max_abs=1.0,
+            min_prob=0.025,
+        )
-        self.activation = DoubleSwish()
 
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
-                                     bias=False, initial_scale=0.1)
+                                     bias=False, initial_scale=0.05)
 
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=10.0,
```
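The parameter-magnitude 'drift' that these comments describe is a scale degeneracy: because the signal passes through `to_bottleneck_proj` and then `from_bottleneck_proj`, multiplying one projection by a constant and dividing the other by the same constant leaves the output unchanged, so nothing in the loss stops their magnitudes from wandering apart. Bounding the intermediate activations removes that freedom. A minimal sketch of the degeneracy (illustrative dimensions and plain matrices, not the real module):

```python
import torch

torch.manual_seed(0)
embed_dim, bottleneck_dim = 8, 4
x = torch.randn(3, embed_dim)

to_bottleneck = torch.randn(embed_dim, bottleneck_dim)
from_bottleneck = torch.randn(bottleneck_dim, embed_dim)

y1 = (x @ to_bottleneck) @ from_bottleneck
# Scale one projection up 100x and the other down 100x: same output,
# but the intermediate (bottleneck) activations are 100x larger.
y2 = (x @ (to_bottleneck * 100.0)) @ (from_bottleneck / 100.0)
print(torch.allclose(y1, y2, atol=1e-3))                      # True (up to float rounding)
print((x @ to_bottleneck).abs().mean().item(),
      (x @ (to_bottleneck * 100.0)).abs().mean().item())      # ~100x apart
```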
```diff
@@ -1334,12 +1355,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
-        bottleneck = self.balancer(bottleneck)
-        bottleneck = self.activation(bottleneck)
+        bottleneck = self.bottleneck_balancer(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
+
+        scales = self.scale_balancer(scales)
+        x = self.activation_balancer(x)
 
         x = self.in_proj(x)
         x = x * scales
         x = self.out_proj(x)
```
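The tensor bookkeeping is unchanged: attention weights of shape (num_heads, batch_size, seq_len, seq_len) average the per-head bottleneck features over time, the result is permuted back to (seq_len, batch_size, bottleneck_dim), projected up to embed_dim, and used as per-channel scales on the main branch. A standalone shape check of that path (hypothetical sizes, plain torch ops, balancers omitted since they are shape-preserving):

```python
import torch

num_heads, batch_size, seq_len = 4, 2, 10
bottleneck_dim, embed_dim = 16, 64
head_dim = bottleneck_dim // num_heads

attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
bottleneck = torch.randn(num_heads, batch_size, seq_len, head_dim)

# (num_heads, batch, seq, seq) x (num_heads, batch, seq, head_dim) -> (num_heads, batch, seq, head_dim)
bottleneck = torch.matmul(attn_weights, bottleneck)
bottleneck = bottleneck.permute(2, 1, 0, 3)                 # (seq, batch, num_heads, head_dim)
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)

from_bottleneck_proj = torch.nn.Linear(bottleneck_dim, embed_dim)  # stand-in for ScaledLinear
scales = from_bottleneck_proj(bottleneck)                   # (seq, batch, embed_dim)

x = torch.randn(seq_len, batch_size, embed_dim)
print(scales.shape, (x * scales).shape)                     # both torch.Size([10, 2, 64])
```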
```diff
@@ -1398,7 +1421,7 @@ class NonlinAttentionModule(nn.Module):
 
         # deriv_balancer corresponds to deriv_balancer2 in ConvolutionMOdule
         self.deriv_balancer = ActivationBalancer(
-            channels, channel_dim=1,
+            channels, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,
             max_abs=20.0,
         )
```
```diff
@@ -1439,8 +1462,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = self.deriv_balancer(x)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        x = self.deriv_balancer(x)
         x = self.activation(x)
         x = self.out_proj(x)
 
```
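Together with the `channel_dim=1 -> -1` change above, moving `deriv_balancer` after the permute/reshape means it now sees a (seq_len, batch_size, channels) tensor and computes its per-channel statistics over the actual channel axis; with the old call site and `channel_dim=1`, the statistics were taken per batch element instead, which the newly added assert would also have caught. A small sketch of why the axis matters (illustrative shapes only, computing the kind of per-channel statistic the balancer relies on):

```python
import torch

seq_len, batch_size, channels = 10, 2, 6
x = torch.randn(seq_len, batch_size, channels)

# Fraction of positive values, reduced over every dim except the "channel" dim:
wrong = (x > 0).float().mean(dim=(0, 2))   # channel_dim=1 -> one value per batch element
right = (x > 0).float().mean(dim=(0, 1))   # channel_dim=-1 -> one value per channel
print(wrong.shape)   # torch.Size([2])  -- statistics over the batch axis
print(right.shape)   # torch.Size([6])  -- matches num_channels, as the new assert requires
```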