Adjust balancers of modules; the most significant change is reducing min_abs of the ff2 balancer from 0.5 to 0.1

Daniel Povey 2022-12-31 14:38:00 +08:00
parent a0c35adca0
commit 577c3ad390


@@ -442,7 +442,7 @@ class ZipformerEncoderLayer(nn.Module):
                                            feedforward_dim,
                                            dropout)
-        self.nonlin_attention_module = NonlinAttentionModule(embed_dim,
+        self.nonlin_attention = NonlinAttention(embed_dim,
                                            hidden_channels=embed_dim // 4)
@@ -461,13 +461,35 @@ class ZipformerEncoderLayer(nn.Module):
             min_positive=0.45, max_positive=0.55,
             min_abs=1.0, max_abs=4.0,
         )
+
+        # balancer for output of NonlinAttentionModule
+        self.balancer_na = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05,  # out of concern for memory usage
+        )
+
+        # balancer for output of AttentionSqueezeModule
+        self.balancer_as = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05,  # out of concern for memory usage
+        )
+
         # balancer for output of feedforward2, prevent it from staying too
         # small.  give this a very small probability, even at the start of
         # training, it's to fix a rare problem and it's OK to fix it slowly.
         self.balancer_ff2 = Balancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.45, max_positive=0.55,
-            min_abs=ScheduledFloat((0.0, 0.0), (8000.0, 0.5), default=0.0),
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
             max_abs=2.0,
             prob=0.05,
         )
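For readers unfamiliar with ScheduledFloat (defined in icefall's scaling.py): it behaves like a float whose value is a piecewise-linear function of the training batch count. Below is a minimal standalone sketch of that behaviour, assuming clamped linear interpolation between the given (batch, value) points; the real class reads the batch count from module state rather than taking it as an argument.

def scheduled_float(batch: float, *points, default: float = 0.0) -> float:
    """Linearly interpolate between (batch_count, value) pairs,
    clamping to the endpoint values outside the given range."""
    points = sorted(points)
    if not points:
        return default
    if batch <= points[0][0]:
        return points[0][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if batch <= x1:
            return y0 + (batch - x0) * (y1 - y0) / (x1 - x0)
    return points[-1][1]

# The new ff2 schedule: min_abs ramps from 0.0 at batch 0 to 0.1 at batch 4000,
# then stays at 0.1 (the old schedule reached 0.5 only at batch 8000).
assert scheduled_float(0.0, (0.0, 0.0), (4000.0, 0.1)) == 0.0
assert abs(scheduled_float(2000.0, (0.0, 0.0), (4000.0, 0.1)) - 0.05) < 1e-9
assert scheduled_float(10000.0, (0.0, 0.0), (4000.0, 0.1)) == 0.1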
@@ -550,7 +572,7 @@ class ZipformerEncoderLayer(nn.Module):
             )
             # else rely on the ones passed in

-        # use different heads for nonlin_attention_module and attention_squeeze, depending
+        # use different heads for nonlin_attention and attention_squeeze, depending
         # on whether this module has its own self_attn_weights submodule or is borrowing
         # attention weights from another one.
         head_offset = 0 if self.self_attn_weights is not None else 2
@@ -569,14 +591,15 @@ class ZipformerEncoderLayer(nn.Module):
             selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.nonlin_attention_module(src,
-                                                     selected_attn_weights[0:1])
+            src = src + self.balancer_na(self.nonlin_attention(src,
+                                                               selected_attn_weights[0:1]))

         src = src + self.feed_forward1(src)

         # pooling module
         if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze(src, selected_attn_weights[1:2])
+            src = src + self.balancer_as(
+                self.attention_squeeze(src, selected_attn_weights[1:2]))

         if torch.jit.is_scripting() or use_self_attn:
             src = src + self.self_attn(
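The new balancer_na and balancer_as wrappers do not change the forward computation: in icefall a Balancer is an identity in the forward pass and only perturbs gradients in the backward pass, nudging per-channel statistics toward the configured bounds. A toy sketch of that mechanism, enforcing only a min_abs constraint (the real Balancer also handles min_positive/max_positive and max_abs, and fires stochastically with probability prob):

import torch
from torch import Tensor

class ToyMinAbsBalancer(torch.autograd.Function):
    """Identity in forward; in backward, adds a small term to the gradient
    that pushes channels whose mean |x| is below min_abs away from zero."""

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float, scale: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.min_abs, ctx.scale = min_abs, scale
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        (x,) = ctx.saved_tensors
        # mean |x| per channel, averaged over all leading dims (channel_dim=-1)
        mean_abs = x.abs().mean(dim=tuple(range(x.dim() - 1)), keepdim=True)
        too_small = (mean_abs < ctx.min_abs).to(x.dtype)
        # a gradient contribution of -sign(x) makes SGD *increase* |x|
        extra = -ctx.scale * too_small * x.sign() * grad.abs().mean()
        return grad + extra, None, None

x = torch.randn(10, 4, requires_grad=True) * 0.001  # tiny activations
y = ToyMinAbsBalancer.apply(x, 0.1, 0.02)
assert torch.equal(y, x)  # the forward pass is a pure identity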
@@ -1359,14 +1382,6 @@ class AttentionSqueeze(nn.Module):
         self.out_proj = ScaledLinear(hidden_dim, embed_dim,
                                      bias=False, initial_scale=0.05)

-        self.out_balancer = Balancer(
-            embed_dim, channel_dim=-1,
-            min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05,  # out of concern for memory usage
-        )
-
     def forward(self,
                 x: Tensor,
                 attn_weights: Tensor):
@@ -1402,7 +1417,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = x * scales
         x = self.activation(x)  # Identity only.  For diagnostics.
         x = self.out_proj(x)
-        x = self.out_balancer(x)
         return x
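For intuition about the weights this module borrows: AttentionSqueeze reuses one head of precomputed attention weights to pool the sequence, roughly an attention-weighted average over positions, and the scales applied above are derived from that pooled representation. A toy shape-level sketch, assuming the (seq_len, batch, embed_dim) layout used elsewhere in this file; all names here are hypothetical:

import torch

num_heads, batch_size, seq_len, embed_dim = 1, 2, 6, 16
x = torch.randn(seq_len, batch_size, embed_dim)
attn_weights = torch.softmax(
    torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)

# (batch, seq, seq) @ (batch, seq, embed) -> (batch, seq, embed): each output
# position is an attention-weighted average over input positions.
pooled = torch.matmul(attn_weights[0], x.transpose(0, 1)).transpose(0, 1)
print(pooled.shape)  # torch.Size([6, 2, 16])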
@@ -1443,7 +1457,7 @@ class FeedforwardModule(nn.Module):
         return x

-class NonlinAttentionModule(nn.Module):
+class NonlinAttention(nn.Module):
     """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
     from the attention module) in place of actual convolution.  We also took out the second nonlinearity, the
     one after the attention mechanism.
@@ -1467,7 +1481,7 @@ class NonlinAttentionModule(nn.Module):
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
         # starting from about 3, and poorly-trained instances of the module have smaller abs values
         # before the sigmoid.
-        self.balancer1 = Balancer(
+        self.balancer = Balancer(
             hidden_channels, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
             max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
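Concretely, these two schedules loosen the sign-proportion bounds over the first 20000 batches. Evaluating the same clamped interpolation as in the scheduled_float sketch earlier (repeated here inline so the snippet is self-contained):

def lerp(batch, x0, y0, x1, y1):
    """Clamped linear interpolation between (x0, y0) and (x1, y1)."""
    t = min(max((batch - x0) / (x1 - x0), 0.0), 1.0)
    return y0 + t * (y1 - y0)

for batch in (0.0, 10000.0, 20000.0):
    lo = lerp(batch, 0.0, 0.25, 20000.0, 0.05)   # min_positive
    hi = lerp(batch, 0.0, 0.75, 20000.0, 0.95)   # max_positive
    print(f"batch {batch:>7.0f}: min_positive={lo:.3f}, max_positive={hi:.3f}")
# batch       0: min_positive=0.250, max_positive=0.750
# batch   10000: min_positive=0.150, max_positive=0.850
# batch   20000: min_positive=0.050, max_positive=0.950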
@@ -1493,13 +1507,6 @@ class NonlinAttentionModule(nn.Module):
             prob=(0.025, 0.25),
             grad_scale=0.01)

-        self.balancer2 = Balancer(
-            channels, channel_dim=-1,
-            min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05,  # out of concern for memory usage
-        )
-
     def forward(self,
                 x: Tensor,
@@ -1521,7 +1528,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         s = x[..., hidden_channels:]
         x = x[..., :hidden_channels]

-        s = self.balancer1(s)
+        s = self.balancer(s)
         s = self.tanh(s)

         s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
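The s channels act as a bounded gate on the hidden channels: the balancer keeps their pre-activation statistics in range, tanh squashes them into (-1, 1), and they presumably multiply x elementwise later in this forward (the multiplication itself is not shown in this hunk). A toy illustration, with shapes assumed:

import torch

seq_len, batch_size, hidden_channels = 5, 2, 8
x = torch.randn(seq_len, batch_size, hidden_channels)
s = torch.randn(seq_len, batch_size, hidden_channels)

gate = torch.tanh(s)   # bounded in (-1, 1), like self.tanh(s) above
gated = x * gate       # elementwise modulation of the hidden channels
print(gated.shape)     # torch.Size([5, 2, 8])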
@@ -1541,8 +1548,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.out_proj(x)
         x = self.whiten2(x)
-        x = self.balancer2(x)
-
         return x