Changes and bug fixes re: balancers; restore the activation in AttentionSqueeze, remove it in NonlinAttention.
commit 836c72dd36
parent 9fe6add587
@@ -32,8 +32,7 @@ from scaling import (
     ScaledConv1d,
     ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
     Whiten,
-    Identity,
-    _diag,
+    Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
     ScheduledFloat,
@@ -212,7 +211,7 @@ class Zipformer(EncoderInterface):
         for i in range(len(z)):
             if i <= 1 or z[i-1] <= z[i]:
                 skip_layers.append(None)
-                skip_modules.append(nn.Identity())
+                skip_modules.append(Identity())
             else:
                 # TEMP
                 for j in range(i-2, -1, -1):
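The Identity used above (imported from scaling in the first hunk) replaces nn.Identity() because, per the import comment, it is more friendly to backward hooks used for diagnostics. A minimal sketch of such a pass-through module follows; the name and details here are illustrative, not the actual class in scaling.py.

import torch
import torch.nn as nn

class PassThroughIdentity(nn.Module):
    # hypothetical stand-in for scaling.Identity: returns its input unchanged,
    # but as a locally defined module it is straightforward to attach
    # forward/backward hooks to for diagnostics.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x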
@@ -1300,11 +1299,12 @@ class AttentionSqueeze(nn.Module):
         self.bottleneck_balancer = ActivationBalancer(
             bottleneck_dim, channel_dim=-1,
             min_positive=0.05, max_positive=0.95,
-            min_abs=0.2,
-            max_abs=1.0,
+            min_abs=0.05,
+            max_abs=2.0,
             max_factor=0.02,
             min_prob=0.1,
         )
+        self.activation = DoubleSwish() # in bottleneck

         # the next two balancers are only to stop parameter-magnitude 'drift': we have
         # too many degrees of freedom for the scales of the various activations.
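For context on the restored bottleneck activation: a minimal sketch of DoubleSwish, assuming it computes x * sigmoid(x - 1) as described in scaling.py; the actual class may use a custom autograd function for memory efficiency, so treat this as an approximation rather than the repo's implementation.

import torch

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # x * sigmoid(x - 1): close to x for large positive x, close to 0 for
    # large negative x, and slightly negative in between.
    return x * torch.sigmoid(x - 1.0)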
@@ -1313,15 +1313,15 @@ class AttentionSqueeze(nn.Module):
         # put no limitation on the signs (so: min_positive=0, max_positive=1).
         self.scale_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.0, max_positive=1.0,
+            min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.025,
+            min_prob=0.05,
         )
         self.activation_balancer = ActivationBalancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.0, max_positive=1.0,
+            min_positive=0.2, max_positive=0.8,
             min_abs=0.2, max_abs=1.0,
-            min_prob=0.025,
+            min_prob=0.05,
         )

         self.from_bottleneck_proj = ScaledLinear(bottleneck_dim, embed_dim)
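To read the parameter changes above: my understanding of ActivationBalancer (from scaling.py, not stated in this diff) is that it is approximately a pass-through in the forward direction, and during training it adjusts gradients so that per-channel statistics stay roughly within the configured ranges: the fraction of positive values between min_positive and max_positive, and the mean absolute value between min_abs and max_abs, applied with probability at least min_prob. A hedged usage sketch with made-up dimensions:

import torch
from scaling import ActivationBalancer  # same import style as the hunk at the top of this diff

embed_dim = 256
balancer = ActivationBalancer(
    embed_dim, channel_dim=-1,
    min_positive=0.2, max_positive=0.8,   # new values from this commit
    min_abs=0.2, max_abs=1.0,
    min_prob=0.05,                        # raised from 0.025 in this commit
)
x = torch.randn(10, 4, embed_dim, requires_grad=True)
y = balancer(x)  # same shape as x; forward values are (near-)unchanged, the effect is on gradients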
@@ -1356,14 +1356,14 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
         bottleneck = self.bottleneck_balancer(bottleneck)
+        bottleneck = self.activation(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3) # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)

-        scales = self.scale_balancer(scales)
-        x = self.activation_balancer(x)
-
         x = self.in_proj(x)
+        x = self.activation_balancer(x)
+        scales = self.scale_balancer(scales)
         x = x * scales
         x = self.out_proj(x)
         x = self.out_whiten(x)
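The permute/reshape bookkeeping around the matmul above can be seen in isolation in the sketch below; the dimensions are made up, and it assumes (as the comment above the matmul indicates) that bottleneck has already been reshaped to (num_heads, batch_size, seq_len, head_dim) before the attention weights are applied.

import torch

num_heads, batch_size, seq_len, head_dim = 4, 2, 10, 16
bottleneck_dim = num_heads * head_dim

attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
bottleneck = torch.randn(num_heads, batch_size, seq_len, head_dim)

out = torch.matmul(attn_weights, bottleneck)            # (num_heads, batch_size, seq_len, head_dim)
out = out.permute(2, 1, 0, 3)                           # (seq_len, batch_size, num_heads, head_dim)
out = out.reshape(seq_len, batch_size, bottleneck_dim)  # (seq_len, batch_size, bottleneck_dim)
print(out.shape)  # torch.Size([10, 2, 64])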
@@ -1399,7 +1399,8 @@ class FeedforwardModule(nn.Module):

 class NonlinAttentionModule(nn.Module):
     """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
-    from the attention module) in place of actual convolution.
+    from the attention module) in place of actual convolution.  We also took out the second nonlinearity, the
+    one after the attention mechanism.

     Args:
       channels (int): The number of channels of conv layers.
@@ -1410,24 +1411,9 @@ class NonlinAttentionModule(nn.Module):
     ) -> None:
         super().__init__()

-        # to_scale and to_value are analogous to pointwise_conv1 in ConvolutionModule
-        # we make them separate because we need an extra degree of freedom for the
-        # scale, as the attention weights are constrained to sum to one so cannot
-        # provide the degree of freedom for the scale of the features before
-        # self.activation().
-        self.to_scale = nn.Linear(channels, channels, bias=True)
-        self.to_value = nn.Linear(channels, channels, bias=True)
-
-
-        # deriv_balancer corresponds to deriv_balancer2 in ConvolutionModule
-        self.deriv_balancer = ActivationBalancer(
-            channels, channel_dim=-1,
-            min_positive=0.05, max_positive=1.0,
-            max_abs=20.0,
-        )
-
-        self.activation = DoubleSwish()
-
+        self.in_proj = nn.Linear(channels, 2 * channels, bias=True)
+
+        self.activation = Identity() # for diagnostics.
         self.out_proj = ScaledLinear(channels, channels,
                                      bias=True,
                                      initial_scale=0.05)
@@ -1443,8 +1429,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         Returns:
            a Tensor with the same shape as x
         """
-        s = self.to_scale(x)
-        v = self.to_value(x)
+        v, s = self.in_proj(x).chunk(2, dim=-1)
+
         if self.training and random.random() < 0.02:
             # prevent the inputs to the sigmoid from getting very large (this is
             # hopefully quite a rare phenomenon, so we are giving this path a
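The change from separate to_scale/to_value projections to a single in_proj followed by chunk(2, dim=-1) fuses two matmuls into one. A small sketch of the equivalence, with made-up dimensions; the weight-slicing below is only to demonstrate the point, it is not something this commit does.

import torch
import torch.nn as nn

channels = 8
x = torch.randn(3, channels)

in_proj = nn.Linear(channels, 2 * channels, bias=True)
v, s = in_proj(x).chunk(2, dim=-1)   # v, s: (3, channels) each

# equivalent pair of projections, copying slices of in_proj's parameters:
to_value = nn.Linear(channels, channels, bias=True)
to_scale = nn.Linear(channels, channels, bias=True)
with torch.no_grad():
    to_value.weight.copy_(in_proj.weight[:channels])
    to_value.bias.copy_(in_proj.bias[:channels])
    to_scale.weight.copy_(in_proj.weight[channels:])
    to_scale.bias.copy_(in_proj.bias[channels:])

assert torch.allclose(v, to_value(x), atol=1e-6)
assert torch.allclose(s, to_scale(x), atol=1e-6)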
@@ -1463,8 +1449,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = torch.matmul(attn_weights, x)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-        x = self.deriv_balancer(x)
-        x = self.activation(x)
+
+        x = self.activation(x) # diagnostics only, it's the identity.
         x = self.out_proj(x)

         return x