mirror of https://github.com/k2-fsa/icefall.git
Remove activation in AttentionSqueeze; add balancers; fix bugs RE balancers.
commit 4e21db07f6 (parent e9806950f5)
@@ -122,40 +122,6 @@ def _compute_sign_factor(x: Tensor,
 
 
 
-class ActivationScaleBalancerFunction(torch.autograd.Function):
-    """
-    This object is used in class ActivationBalancer when the user specified
-    min_positive=0, max_positive=1, so there are no constraints on the signs
-    of the activations and only the absolute value has a constraint.
-    """
-    @staticmethod
-    def forward(
-        ctx,
-        x: Tensor,
-        sign_factor: Tensor,
-        scale_factor: Tensor,
-        channel_dim: int,
-    ) -> Tensor:
-        if channel_dim < 0:
-            channel_dim += x.ndim
-        ctx.channel_dim = channel_dim
-        xgt0 = (x > 0)
-        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
-        return x
-
-
-    @staticmethod
-    def backward(
-        ctx, x_grad: Tensor
-    ) -> Tuple[Tensor, None, None, None]:
-        xgt0, sign_factor, scale_factor = ctx.saved_tensors
-        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
-            sign_factor = sign_factor.unsqueeze(-1)
-            scale_factor = scale_factor.unsqueeze(-1)
-
-        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
-        neg_delta_grad = x_grad.abs() * factor
-        return x_grad - neg_delta_grad, None, None, None,
 
 
 def random_cast_to_half(x: Tensor,
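The class removed above was an identity in the forward pass whose backward pass subtracted a small, sign-dependent penalty from the incoming gradient. A minimal standalone sketch of that gradient rule, with a toy 2-D input and made-up per-channel factors (none of these values come from the repo):

import torch

x = torch.randn(4, 3)                   # (batch, channels), toy input
sign_factor = torch.full((3,), 0.01)    # hypothetical per-channel factors
scale_factor = torch.full((3,), 0.02)

# factor = sign_factor + scale_factor * ((x > 0) - 0.5): positive activations
# get sign_factor + scale_factor/2, negative ones sign_factor - scale_factor/2.
factor = sign_factor + scale_factor * ((x > 0).float() - 0.5)

upstream_grad = torch.ones_like(x)
balanced_grad = upstream_grad - upstream_grad.abs() * factor
# What backward() returned: the upstream gradient minus |grad| * factor, a
# gentle push of the activations toward the requested sign/magnitude statistics.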
@@ -465,6 +431,7 @@ class ActivationBalancer(torch.nn.Module):
         # loop.
         self.batch_count = 0
 
+        # actually self.num_channels is no longer needed except for an assertion.
         self.num_channels = num_channels
         self.channel_dim = channel_dim
         self.min_positive = min_positive
@@ -489,6 +456,7 @@ class ActivationBalancer(torch.nn.Module):
         prob = max(self.min_prob, 0.5 ** (1 + (self.batch_count / 4000.0)))
 
         if random.random() < prob:
+            assert x.shape[self.channel_dim] == self.num_channels
             sign_gain_factor = 0.5
             if self.min_positive != 0.0 or self.max_positive != 1.0:
                 sign_factor = _compute_sign_factor(x, self.channel_dim,
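The context line above shows the schedule that decides how often the balancer actually runs: the probability starts at 0.5 and halves every 4000 batches until it reaches the min_prob floor. A quick standalone check of that formula (the min_prob value here is only an example):

min_prob = 0.1  # assumed floor, for illustration
for batch_count in (0, 4000, 8000, 20000):
    prob = max(min_prob, 0.5 ** (1 + (batch_count / 4000.0)))
    print(batch_count, prob)
# 0 -> 0.5, 4000 -> 0.25, 8000 -> 0.125, 20000 -> 0.1 (clamped at min_prob)

The new assert makes the module fail loudly if the tensor it is given does not have num_channels entries on channel_dim, which is the kind of mismatch the AttentionSqueeze and NonlinAttentionModule fixes below address.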
@@ -1288,25 +1288,46 @@ class AttentionSqueeze(nn.Module):
                                             bottleneck_dim,
                                             bias=False)
 
+
+        # the main reason for this balancer is to keep the bottleneck activations in a "reasonable"
+        # dynamic range, to avoid parameter-size 'drift' where to_bottleneck_proj gets large and from_bottleneck_proj
+        # gets small or vice versa.
         # Caution: this cannot work correctly with an extremely small batch size, e.g. if
         # we were training with a single very long audio sequence, or just 2 or 3 sequences
         # at a time. We make max_factor small to reduce the harm this could cause
         # (although when the grads get back past the averaging operation they would
         # be quite small and would probably not hurt the rest of the model much.)
-        self.balancer = ActivationBalancer(
-            embed_dim, channel_dim=-1,
+        self.bottleneck_balancer = ActivationBalancer(
+            bottleneck_dim, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
-            min_abs=0.1,
-            max_abs=50.0,
+            min_abs=0.2,
+            max_abs=1.0,
             max_factor=0.02,
-            min_prob=0.2,
+            min_prob=0.1,
         )
-        self.activation = DoubleSwish()
+
+        # the next two balancers are only to stop parameter-magnitude 'drift': we have
+        # too many degrees of freedom for the scales of the various activations.
+        # Make them run with very low probability, since only a small application of
+        # these balancers should be enough to stop such "drift"; and, for speed,
+        # put no limitation on the signs (so: min_positive=0, max_positive=1).
+        self.scale_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.0, max_positive=1.0,
+            min_abs=0.2, max_abs=1.0,
+            min_prob=0.025,
+        )
+        self.activation_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.0, max_positive=1.0,
+            min_abs=0.2, max_abs=1.0,
+            min_prob=0.025,
+        )
 
         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
 
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
-                                     bias=False, initial_scale=0.1)
+                                     bias=False, initial_scale=0.05)
 
         self.out_whiten = Whiten(num_groups=1,
                                  whitening_limit=10.0,
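The 'drift' the new comments refer to comes from a scaling symmetry: with no nonlinearity left between the two bottleneck projections, multiplying one by a constant and dividing the other by the same constant leaves the module's output unchanged, so nothing in the loss pins down the magnitude of the bottleneck activations. A toy sketch of that symmetry using stand-in nn.Linear layers and made-up sizes (not the module's actual code):

import torch
import torch.nn as nn

torch.manual_seed(0)
to_bottleneck = nn.Linear(8, 4, bias=False)    # stands in for to_bottleneck_proj
from_bottleneck = nn.Linear(4, 8, bias=False)  # stands in for from_bottleneck_proj

x = torch.randn(2, 8)
y_before = from_bottleneck(to_bottleneck(x))

with torch.no_grad():          # rescale the pair: one 10x up, the other 10x down
    to_bottleneck.weight *= 10.0
    from_bottleneck.weight /= 10.0
y_after = from_bottleneck(to_bottleneck(x))

print(torch.allclose(y_before, y_after, atol=1e-5))  # True: identical output,
# but the bottleneck activations are now 10x larger.  Constraining their mean
# absolute value (min_abs/max_abs on bottleneck_balancer) removes this degree
# of freedom, and the two low-probability balancers do the same for the other
# activations in the module.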
@@ -1334,12 +1355,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # (num_heads, batch_size, seq_len, seq_len) x (num_heads, batch_size, seq_len, head_dim)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
-        bottleneck = self.balancer(bottleneck)
-        bottleneck = self.activation(bottleneck)
+        bottleneck = self.bottleneck_balancer(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
 
+        scales = self.scale_balancer(scales)
+        x = self.activation_balancer(x)
+
         x = self.in_proj(x)
         x = x * scales
         x = self.out_proj(x)
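For reference, here is the shape bookkeeping in this part of the forward pass as a standalone walk-through with arbitrary toy sizes (the balancers are left out since they do not change values in the forward pass):

import torch

num_heads, batch_size, seq_len, head_dim = 4, 2, 10, 8   # toy sizes
bottleneck_dim = num_heads * head_dim                     # implied by the reshape

attn_weights = torch.softmax(
    torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
bottleneck = torch.randn(num_heads, batch_size, seq_len, head_dim)

# (num_heads, batch, seq, seq) x (num_heads, batch, seq, head_dim)
#   -> (num_heads, batch, seq, head_dim): per-head weighted average over time
bottleneck = torch.matmul(attn_weights, bottleneck)

bottleneck = bottleneck.permute(2, 1, 0, 3)   # (seq, batch, heads, head_dim)
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
print(bottleneck.shape)                       # torch.Size([10, 2, 32])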
@@ -1398,7 +1421,7 @@ class NonlinAttentionModule(nn.Module):
 
         # deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
         self.deriv_balancer = ActivationBalancer(
-            channels, channel_dim=1,
+            channels, channel_dim=-1,
             min_positive=0.05, max_positive=1.0,
             max_abs=20.0,
         )
@@ -1439,8 +1462,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
-        x = self.deriv_balancer(x)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        x = self.deriv_balancer(x)
         x = self.activation(x)
         x = self.out_proj(x)
 
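These last two hunks are the balancer bug fixes the commit message mentions: deriv_balancer is constructed with `channels` channels, but it used to be applied (with channel_dim=1) to a tensor whose dim 1 is the batch axis; after the permute/reshape the last dimension really is the channel dimension, so the call moves there and channel_dim becomes -1. A small shape check with toy sizes (assuming channels == num_heads * head_dim, as the reshape suggests):

import torch

num_heads, batch_size, seq_len, head_dim = 4, 2, 10, 8   # toy sizes
channels = num_heads * head_dim                           # assumed relationship

x = torch.randn(num_heads, batch_size, seq_len, head_dim)
print(x.shape[1] == channels)    # False: at the old call site, dim 1 is batch_size

x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
print(x.shape[-1] == channels)   # True: after the move, channel_dim=-1 lines up

With the assert added to ActivationBalancer above, the old placement would now fail loudly instead of silently balancing the wrong dimension.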