Adjust balancers of modules; the most significant change is reducing min_abs of the ff2 balancer from 0.5 to 0.1
parent a0c35adca0
commit 577c3ad390
@@ -442,7 +442,7 @@ class ZipformerEncoderLayer(nn.Module):
                                                feedforward_dim,
                                                dropout)

-        self.nonlin_attention_module = NonlinAttentionModule(embed_dim,
+        self.nonlin_attention = NonlinAttention(embed_dim,
                                                 hidden_channels=embed_dim // 4)


@@ -461,13 +461,35 @@ class ZipformerEncoderLayer(nn.Module):
             min_positive=0.45, max_positive=0.55,
             min_abs=1.0, max_abs=4.0,
         )

+        # balancer for output of NonlinAttentionModule
+        self.balancer_na = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05, # out of concern for memory usage
+        )
+
+        # balancer for output of AttentionSqueezeModule
+        self.balancer_as = Balancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
+            prob=0.05, # out of concern for memory usage
+        )
+
         # balancer for output of feedforward2, prevent it from staying too
         # small. give this a very small probability, even at the start of
         # training, it's to fix a rare problem and it's OK to fix it slowly.
         self.balancer_ff2 = Balancer(
             embed_dim, channel_dim=-1,
-            min_positive=0.45, max_positive=0.55,
-            min_abs=ScheduledFloat((0.0, 0.0), (8000.0, 0.5), default=0.0),
+            min_positive=0.3, max_positive=0.7,
+            min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0),
             max_abs=2.0,
             prob=0.05,
         )
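The min_abs floors above are scheduled over training batch count: the ff2 balancer now ramps from 0.0 to 0.1 over the first 4000 batches instead of 0.0 to 0.5 over 8000, and the new balancer_na/balancer_as floors ramp from 0.004 to 0.02 by batch 4000. Below is a minimal sketch of how such (batch, value) breakpoints behave, assuming simple piecewise-linear interpolation with clamping at the ends; scheduled_value is an illustration written for this note, not icefall's ScheduledFloat implementation.

# Illustrative only: piecewise-linear interpolation between (batch, value)
# breakpoints, approximating how a schedule like (0.0, 0.0) -> (4000.0, 0.1)
# evolves over training. Not icefall's ScheduledFloat implementation.
from typing import List, Tuple

def scheduled_value(batch_count: float, points: List[Tuple[float, float]]) -> float:
    points = sorted(points)
    if batch_count <= points[0][0]:
        return points[0][1]
    if batch_count >= points[-1][0]:
        return points[-1][1]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if x0 <= batch_count <= x1:
            return y0 + (batch_count - x0) / (x1 - x0) * (y1 - y0)
    return points[-1][1]

# Old ff2 schedule: min_abs reaches 0.5 at batch 8000.
print(scheduled_value(2000.0, [(0.0, 0.0), (8000.0, 0.5)]))  # 0.125
# New ff2 schedule: min_abs reaches 0.1 at batch 4000, a much weaker floor.
print(scheduled_value(2000.0, [(0.0, 0.0), (4000.0, 0.1)]))  # 0.05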
@@ -550,7 +572,7 @@ class ZipformerEncoderLayer(nn.Module):
                 )
            # else rely on the ones passed in

        # use different heads for nonlin_attention and attention_squeeze, depending
        # whether this module has its on self_attn_weights submodule or is borrowing
        # attention weights from another one.
        head_offset = 0 if self.self_attn_weights is not None else 2
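The comment above says different attention heads feed nonlin_attention and attention_squeeze, with an offset of 2 when this layer borrows weights from another layer. The sketch below uses hypothetical shapes taken from the docstring later in this diff, and the head selection shown is only one plausible reading of that comment; it mainly illustrates that narrow slices like [0:1] and [1:2] keep the head dimension.

import torch

# Hypothetical shapes, following the docstring later in this diff:
# attn_weights is (num_heads, batch_size, seq_len, seq_len).
num_heads, batch_size, seq_len = 8, 2, 10
attn_weights = torch.rand(num_heads, batch_size, seq_len, seq_len)

has_own_attn_weights = False  # e.g. this layer borrows weights from another one
head_offset = 0 if has_own_attn_weights else 2

# Two consecutive heads starting at head_offset (one plausible reading of the
# comment): range slicing keeps the head dimension.
selected_attn_weights = attn_weights[head_offset:head_offset + 2]
print(selected_attn_weights.shape)       # torch.Size([2, 2, 10, 10])
print(selected_attn_weights[0:1].shape)  # torch.Size([1, 2, 10, 10])
print(selected_attn_weights[1:2].shape)  # torch.Size([1, 2, 10, 10])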
@@ -569,14 +591,15 @@ class ZipformerEncoderLayer(nn.Module):
            selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)

        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.nonlin_attention_module(src,
-                                                     selected_attn_weights[0:1])
+            src = src + self.balancer_na(self.nonlin_attention(src,
+                                                               selected_attn_weights[0:1]))

        src = src + self.feed_forward1(src)

        # pooling module
        if torch.jit.is_scripting() or use_self_attn:
-            src = src + self.attention_squeeze(src, selected_attn_weights[1:2])
+            src = src + self.balancer_as(
+                self.attention_squeeze(src, selected_attn_weights[1:2]))

        if torch.jit.is_scripting() or use_self_attn:
            src = src + self.self_attn(
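In the forward pass the new balancers wrap the sub-module outputs inside the residual updates, rather than living inside the sub-modules themselves (the in-module out_balancer/balancer2 are removed further down). Here is a minimal sketch of that wrapping pattern; IdentityBalancer and TinyLayer are placeholders invented for illustration, not the real Balancer or NonlinAttention modules.

import torch
import torch.nn as nn

class IdentityBalancer(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # pass-through stand-in for the constraint module

class TinyLayer(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.branch = nn.Linear(embed_dim, embed_dim)  # stand-in for nonlin_attention
        self.balancer_na = IdentityBalancer()

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # Residual update with the balancer wrapped around the branch output,
        # mirroring `src = src + self.balancer_na(self.nonlin_attention(...))`.
        return src + self.balancer_na(self.branch(src))

layer = TinyLayer(embed_dim=16)
print(layer(torch.rand(10, 2, 16)).shape)  # torch.Size([10, 2, 16])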
@@ -1359,14 +1382,6 @@ class AttentionSqueeze(nn.Module):
         self.out_proj = ScaledLinear(hidden_dim, embed_dim,
                                      bias=False, initial_scale=0.05)

-        self.out_balancer = Balancer(
-            embed_dim, channel_dim=-1,
-            min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05, # out of concern for memory usage
-        )
-
-
    def forward(self,
                x: Tensor,
                attn_weights: Tensor):
@@ -1402,7 +1417,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        x = x * scales
        x = self.activation(x) # Identity only. For diagnostics.
        x = self.out_proj(x)
-        x = self.out_balancer(x)
        return x


@@ -1443,7 +1457,7 @@ class FeedforwardModule(nn.Module):
        return x


-class NonlinAttentionModule(nn.Module):
+class NonlinAttention(nn.Module):
    """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed
    from the attention module) in place of actual convolution. We also took out the second nonlinearity, the
    one after the attention mechanism.
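The docstring describes the core idea: instead of mixing neighbouring frames with a fixed convolution kernel, each frame's features are mixed using weights borrowed from the attention module. The schematic below shows that idea with illustrative shapes and a randomly generated attention matrix; it is not the module's exact internals.

import torch

# Schematic of "multiplication by attention weights in place of convolution":
# each output frame becomes an attention-weighted mix of input frames, with
# the mixing weights borrowed from attention rather than a learned kernel.
seq_len, batch_size, channels = 10, 2, 16
x = torch.rand(seq_len, batch_size, channels)
attn = torch.softmax(torch.rand(batch_size, seq_len, seq_len), dim=-1)

x_bsc = x.permute(1, 0, 2)      # (batch, seq, channels)
mixed = torch.bmm(attn, x_bsc)  # weighted sum over the time axis
mixed = mixed.permute(1, 0, 2)  # back to (seq, batch, channels)
print(mixed.shape)              # torch.Size([10, 2, 16])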
@@ -1467,7 +1481,7 @@ class NonlinAttentionModule(nn.Module):
        # because we noticed that well-trained instances of this module have abs-value before the sigmoid
        # starting from about 3, and poorly-trained instances of the module have smaller abs values
        # before the sigmoid.
-        self.balancer1 = Balancer(
+        self.balancer = Balancer(
            hidden_channels, channel_dim=-1,
            min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
            max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
@@ -1493,13 +1507,6 @@ class NonlinAttentionModule(nn.Module):
                              prob=(0.025, 0.25),
                              grad_scale=0.01)

-        self.balancer2 = Balancer(
-            channels, channel_dim=-1,
-            min_positive=0.3, max_positive=0.7,
-            min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)),
-            prob=0.05, # out of concern for memory usage
-        )
-

    def forward(self,
                x: Tensor,
@@ -1521,7 +1528,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
        s = x[..., hidden_channels:]
        x = x[..., :hidden_channels]

-        s = self.balancer1(s)
+        s = self.balancer(s)
        s = self.tanh(s)

        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
@@ -1541,8 +1548,6 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)

        x = self.out_proj(x)
        x = self.whiten2(x)
-        x = self.balancer2(x)

        return x

