Merge branch 'scaled_adam_exp647' into scaled_adam_exp652
# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
commit 1718b2de44
egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py

@@ -432,8 +432,7 @@ class ZipformerEncoderLayer(nn.Module):
                                                              dropout)
 
         self.nonlin_attention_module = NonlinAttentionModule(embed_dim,
-                                                             hidden_channels=embed_dim // 4,
-                                                             ratio=1)
+                                                             hidden_channels=embed_dim // 4)
 
 
         self.conv_module = ConvolutionModule(embed_dim,
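Note on why dropping the `ratio` argument here is shape-preserving: this call site always passed `ratio=1`, and with `ratio == 1` the old projection width `hidden_channels + hidden_channels // ratio` (see the constructor hunk below) equals the new `hidden_channels * 2`. A quick check of that arithmetic; the concrete width is only illustrative:

```python
# Illustrative check that removing `ratio` (always 1 at this call site) keeps
# the in_proj width unchanged: hidden + hidden // 1 == 2 * hidden.
hidden_channels = 96  # e.g. embed_dim // 4 with embed_dim = 384; value is illustrative
ratio = 1
assert hidden_channels + hidden_channels // ratio == hidden_channels * 2
```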
@@ -1470,22 +1469,19 @@ class NonlinAttentionModule(nn.Module):
         self,
         channels: int,
         hidden_channels: int,
-        ratio: int = 1,
     ) -> None:
         super().__init__()
 
-        self.ratio = ratio
         self.hidden_channels = hidden_channels
 
-        assert channels % (ratio * 2) == 0
-        self.in_proj = nn.Linear(channels, hidden_channels + hidden_channels // ratio, bias=True)
+        self.in_proj = nn.Linear(channels, hidden_channels * 2, bias=True)
 
         # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0,
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
         # starting from about 3, and poorly-trained instances of the module have smaller abs values
         # before the sigmoid.
         self.balancer1 = ActivationBalancer(
-            hidden_channels // ratio, channel_dim=-1,
+            hidden_channels, channel_dim=-1,
             min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)),
             max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)),
             min_abs=0.75,
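After this hunk, in_proj always produces 2 * hidden_channels features. The forward() hunks further down show these being used as a tanh gate s and a value tensor x that gets multiplied by it. A minimal sketch of that gating, assuming an even split of the projection output along the last dimension (the split itself is outside the hunks shown in this diff):

```python
import torch

seq_len, batch_size, hidden_channels = 10, 2, 96          # illustrative sizes
proj_out = torch.randn(seq_len, batch_size, hidden_channels * 2)

# Assumed even split of the in_proj output into a gate half and a value half.
s, x = proj_out.chunk(2, dim=-1)
s = torch.tanh(s)   # the diff applies balancer1 and then tanh to s
x = x * s           # gated values, matching `x = x * s` in the forward hunk
```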
@@ -1498,22 +1494,23 @@ class NonlinAttentionModule(nn.Module):
             bias=True,
             initial_scale=0.05)
 
-        # Have very tight limits on min_positive and max_positive so that it beomes
-        # close to zero mean, as we found that large mean offsets after the
-        # multiplication are associated with poor convergence.
-        # We don't need min_abs and max_abs limits because sharing the in_proj
-        # between the sigmoid-input and activations dictates the scale of the
-        # activations at this point. The code applies those anyway, it's not optional
-        # right now, so just use the default values.
-        self.balancer2 = ActivationBalancer(
-            hidden_channels // ratio, channel_dim=-1,
-            min_positive=0.4, max_positive=0.6,
-        )
-
-        self.whiten = Whiten(num_groups=1,
-                             whitening_limit=_whitening_schedule(5.0),
-                             prob=(0.025, 0.25),
-                             grad_scale=0.01)
+
+        self.whiten1 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(5.0),
+                              prob=(0.025, 0.25),
+                              grad_scale=0.01)
+
+        self.whiten2 = Whiten(num_groups=1,
+                              whitening_limit=_whitening_schedule(5.0),
+                              prob=(0.025, 0.25),
+                              grad_scale=0.01)
+
+        self.balancer2 = ActivationBalancer(
+            channels, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=ScheduledFloat((0.0, 0.001), (8000.0, 0.01))
+        )
+
 
 
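Reading the ScheduledFloat arguments in the added balancer2: each (x, y) pair is read as (training step, value), with piecewise-linear interpolation between the points and clamping outside them, so min_abs ramps from 0.001 at step 0 to 0.01 by step 8000. This follows the ScheduledFloat convention in icefall's scaling utilities; the sketch below only illustrates how to read those pairs and is not the library code:

```python
def scheduled_value(step: float, points) -> float:
    """Piecewise-linear interpolation over (step, value) points, clamped at the ends."""
    if step <= points[0][0]:
        return points[0][1]
    if step >= points[-1][0]:
        return points[-1][1]
    for (xa, ya), (xb, yb) in zip(points, points[1:]):
        if xa <= step <= xb:
            return ya + (yb - ya) * (step - xa) / (xb - xa)

# min_abs schedule from the added balancer2: ScheduledFloat((0.0, 0.001), (8000.0, 0.01))
print(scheduled_value(0, [(0.0, 0.001), (8000.0, 0.01)]))       # 0.001
print(scheduled_value(4000, [(0.0, 0.001), (8000.0, 0.01)]))    # 0.0055
print(scheduled_value(20000, [(0.0, 0.001), (8000.0, 0.01)]))   # 0.01
```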
@@ -1540,8 +1537,8 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         s = self.balancer1(s)
         s = self.tanh(s)
 
-        s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size,
-                                                                   hidden_channels)
+        s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
+        x = self.whiten1(x)
         x = self.activation(x) # diagnostics only, it's the identity.
         x = x * s
 
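A side effect of removing `ratio` is visible in this hunk: s already has hidden_channels features, so the surviving unsqueeze(-1).reshape(...) is now a round trip back to the same shape (the old code expanded the last dimension by `ratio` at this point). A quick check with illustrative sizes:

```python
import torch

seq_len, batch_size, hidden_channels = 10, 2, 96   # illustrative sizes
s = torch.randn(seq_len, batch_size, hidden_channels)

# With the `ratio` expansion gone, this reshape returns the original shape and values.
s2 = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
assert torch.equal(s, s2)
```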
@@ -1555,9 +1552,10 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
 
-        x = self.balancer2(x)
-        x = self.whiten(x)
         x = self.out_proj(x)
+        x = self.whiten2(x)
+        x = self.balancer2(x)
 
         return x
 
+
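The last hunk reorders the output tail to out_proj, then whiten2, then balancer2, with the whitening and balancing now applied to the full `channels`-wide output. The permute/reshape just before it merges the attention heads back into the feature dimension; a quick shape check with illustrative sizes (the hunk header documents attn_weights as (num_heads, batch_size, seq_len, seq_len)):

```python
import torch

num_heads, batch_size, seq_len, head_dim = 4, 2, 10, 24   # illustrative sizes
x = torch.randn(num_heads, batch_size, seq_len, head_dim)

# As in the hunk: (num_heads, batch, seq, head_dim) -> (seq, batch, num_heads * head_dim),
# so the following out_proj sees all heads concatenated along the feature dimension.
x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
assert x.shape == (seq_len, batch_size, num_heads * head_dim)
```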