Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Change balancer to whitener for ff module; tighter min/max-pos limit on NonlinAttentionModule; whitener->balancer for AttentionSqueeze.
This commit is contained in:
parent 26916f41e7
commit 7acdaea085
@@ -1306,7 +1306,8 @@ class AttentionSqueeze(nn.Module):
             max_factor=0.02,
             min_prob=0.1,
         )
-        self.activation = DoubleSwish()  # in bottleneck
+        self.bottleneck_activation = DoubleSwish()  # in bottleneck
+        self.activation = Identity()  # for diagnostics
 
         # the next two balancers are only to stop parameter-magnitude 'drift': we have
         # too many degrees of freedom for the scales of the various activations.
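
Note on the two activations above: DoubleSwish is icefall's double-swish nonlinearity (x * sigmoid(x - 1) in scaling.py), and keeping self.activation as an Identity leaves a stable, named hook point for diagnostics. Below is a minimal sketch of the forward math only, using nn.Identity as a stand-in for icefall's own Identity class and omitting the memory-efficient backward of the real DoubleSwish.

# Sketch only: forward math of the two activations named in the hunk above,
# assuming the usual icefall definition DoubleSwish(x) = x * sigmoid(x - 1).
import torch
from torch import Tensor, nn


class DoubleSwishSketch(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        return x * torch.sigmoid(x - 1.0)


bottleneck_activation = DoubleSwishSketch()  # nonlinearity inside the bottleneck
activation = nn.Identity()                   # pass-through kept for diagnostics

x = torch.randn(4, 8)
assert torch.equal(activation(x), x)         # Identity changes nothing
print(bottleneck_activation(x).shape)        # torch.Size([4, 8])
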
@@ -1331,11 +1332,12 @@ class AttentionSqueeze(nn.Module):
         self.out_proj = ScaledLinear(embed_dim, embed_dim,
                                      bias=False, initial_scale=0.05)
 
-        self.out_whiten = Whiten(num_groups=1,
-                                 whitening_limit=10.0,
-                                 prob=(0.025, 0.25),
-                                 grad_scale=0.01)
-
+        self.out_balancer = ActivationBalancer(
+            embed_dim, channel_dim=-1,
+            min_positive=0.45, max_positive=0.55,
+            min_abs=0.005, max_abs=2.0,
+            min_prob=0.05,
+        )
 
     def forward(self,
                 x: Tensor,
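
This hunk is the "whitener->balancer for AttentionSqueeze" part of the commit message: the module's output constraint changes from decorrelating features (Whiten) to balancing per-channel sign and magnitude statistics (ActivationBalancer). The sketch below only measures the statistics the new balancer targets; it is not icefall's implementation, which nudges them through small gradient corrections applied with a probability controlled by min_prob.

# Sketch only: the per-channel statistics constrained by the new out_balancer,
# for an activation shaped (seq_len, batch_size, embed_dim).
import torch


def balancer_stats(x: torch.Tensor, channel_dim: int = -1):
    num_channels = x.shape[channel_dim]
    x = x.transpose(channel_dim, -1).reshape(-1, num_channels)
    frac_positive = (x > 0).float().mean(dim=0)  # target range: [0.45, 0.55] per channel
    mean_abs = x.abs().mean(dim=0)               # target range: [0.005, 2.0] per channel
    return frac_positive, mean_abs


frac_pos, mean_abs = balancer_stats(torch.randn(100, 8, 256))
print(frac_pos.min().item(), frac_pos.max().item(), mean_abs.mean().item())
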
@@ -1358,7 +1360,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # -> (num_heads, batch_size, seq_len, head_dim)
         bottleneck = torch.matmul(attn_weights, bottleneck)
         bottleneck = self.bottleneck_balancer(bottleneck)
-        bottleneck = self.activation(bottleneck)
+        bottleneck = self.bottleneck_activation(bottleneck)
         bottleneck = bottleneck.permute(2, 1, 0, 3)  # (seq_len, batch_size, num_heads, head_dim)
         bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
         scales = self.from_bottleneck_proj(bottleneck)
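
Only the activation's name changes in this hunk. For readers following the shapes in the surrounding context, the sketch below walks the bottleneck tensor through the same matmul/permute/reshape steps with made-up sizes; that bottleneck_dim equals num_heads * head_dim is implied by the reshape, not stated in the diff.

# Shape sketch of the bottleneck path above (illustrative sizes only).
import torch

num_heads, batch_size, seq_len, head_dim = 4, 2, 10, 16
bottleneck_dim = num_heads * head_dim

attn_weights = torch.softmax(torch.randn(num_heads, batch_size, seq_len, seq_len), dim=-1)
bottleneck = torch.randn(num_heads, batch_size, seq_len, head_dim)

bottleneck = torch.matmul(attn_weights, bottleneck)  # (num_heads, batch, seq, head_dim)
bottleneck = bottleneck.permute(2, 1, 0, 3)          # (seq, batch, num_heads, head_dim)
bottleneck = bottleneck.reshape(seq_len, batch_size, bottleneck_dim)
print(bottleneck.shape)                              # torch.Size([10, 2, 64])
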
@@ -1367,8 +1369,9 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         x = self.activation_balancer(x)
         scales = self.scale_balancer(scales)
         x = x * scales
+        x = self.activation(x)  # Identity only. For diagnostics.
         x = self.out_proj(x)
-        x = self.out_whiten(x)
+        x = self.out_balancer(x)
         return x
 
 
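
The added x = self.activation(x) is a no-op in the computation (self.activation is now Identity) and exists, per its comment, for diagnostics. Below is a generic PyTorch sketch of how such a pass-through module can be inspected with a forward hook; this is an illustration, not icefall's diagnostics machinery.

# Sketch: attach a forward hook to an Identity "diagnostics" module to record
# activation statistics without changing the computation.
import torch
from torch import nn

activation = nn.Identity()
stats = []


def log_stats(module, inputs, output):
    stats.append((output.mean().item(), output.abs().mean().item()))


handle = activation.register_forward_hook(log_stats)
_ = activation(torch.randn(10, 2, 256))
handle.remove()
print(stats)  # [(mean, mean-abs)] of the tensor that flowed through
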
@@ -1388,11 +1391,10 @@ class FeedforwardModule(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.out_proj = ScaledLinear(feedforward_dim, embed_dim,
                                      initial_scale=0.01)
-        self.out_balancer = ActivationBalancer(embed_dim,
-                                               min_positive=0.4, max_positive=0.6,
-                                               min_abs=0.01, max_abs=5.0,
-                                               channel_dim=-1, min_prob=0.1)
-
+        self.out_whitener = Whiten(num_groups=1,
+                                   whitening_limit=10.0,
+                                   prob=(0.025, 0.25),
+                                   grad_scale=0.01)
 
     def forward(self,
                 x: Tensor):
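
This is the "balancer to whitener for ff module" half of the commit: the FeedforwardModule output loses its ActivationBalancer and gains a Whiten, which discourages strongly correlated output channels by applying a small gradient penalty (scaled by grad_scale, with a probability in the prob range). The function below is a rough stand-in diagnostic for how "white" the features are; it is not icefall's internal whitening metric.

# Rough stand-in diagnostic: off-diagonal vs. diagonal energy of the feature
# covariance. Near 0 means the channels are nearly decorrelated ("white").
import torch


def whiteness(x: torch.Tensor) -> float:
    x = x.reshape(-1, x.shape[-1])        # (frames, channels); num_groups=1
    x = x - x.mean(dim=0, keepdim=True)
    cov = (x.t() @ x) / x.shape[0]
    off_diag = cov - torch.diag(torch.diag(cov))
    return (off_diag.pow(2).sum() / cov.diag().pow(2).sum().clamp(min=1e-20)).item()


print(whiteness(torch.randn(1000, 8, 256)))  # small for i.i.d. Gaussian features
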
@@ -1401,7 +1403,7 @@ class FeedforwardModule(nn.Module):
         x = self.activation(x)
         x = self.dropout(x)
         x = self.out_proj(x)
-        x = self.out_balancer(x)
+        x = self.out_whitener(x)
         return x
 
 
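
Combining this hunk with the previous one, the FeedforwardModule's output path now ends in out_proj followed by out_whitener. The sketch below shows that data flow with plain PyTorch stand-ins (nn.Linear for ScaledLinear, nn.Identity for Whiten); the in_proj and the choice of nonlinearity are assumptions, since the diff does not show them.

# Sketch of the FeedforwardModule data flow after this commit, with plain
# PyTorch stand-ins for icefall's ScaledLinear and Whiten.
import torch
from torch import Tensor, nn


class FeedforwardSketch(nn.Module):
    def __init__(self, embed_dim: int, feedforward_dim: int, dropout: float):
        super().__init__()
        self.in_proj = nn.Linear(embed_dim, feedforward_dim)   # assumed; not shown in the diff
        self.activation = nn.SiLU()                            # stand-in nonlinearity
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(feedforward_dim, embed_dim)  # ScaledLinear(..., initial_scale=0.01) in icefall
        self.out_whitener = nn.Identity()                      # Whiten(...) in icefall

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        x = self.out_whitener(x)
        return x


print(FeedforwardSketch(256, 1024, 0.1)(torch.randn(10, 2, 256)).shape)  # torch.Size([10, 2, 256])
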
@@ -1447,7 +1449,7 @@ class NonlinAttentionModule(nn.Module):
         # to have a larger mean-offset at the output for some reason.
         self.out_balancer = ActivationBalancer(
             channels, channel_dim=-1,
-            min_positive=0.4, max_positive=0.6,
+            min_positive=0.45, max_positive=0.55,
             min_abs=0.005, max_abs=1.0,
             min_prob=0.05,
         )
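
The NonlinAttentionModule change tightens the balancer's allowed fraction of positive values per channel from [0.4, 0.6] to [0.45, 0.55], i.e. each output channel's median is constrained to sit closer to zero. A quick numerical illustration with made-up numbers:

# Illustration: fraction of positive values per channel, the statistic the
# balancer's min_positive/max_positive bounds act on.
import torch

torch.manual_seed(0)
centered = torch.randn(100_000)
offset = centered + 0.2
print((centered > 0).float().mean().item())  # ~0.50: inside both the old and new windows
print((offset > 0).float().mean().item())    # ~0.58: inside old [0.4, 0.6], outside new [0.45, 0.55]
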