mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp569' into scaled_adam_exp576
This commit is contained in:
commit
cac1a8b860
@ -1471,10 +1471,10 @@ class NonlinAttentionModule(nn.Module):
|
|||||||
channels // (2 * ratio), channel_dim=-1,
|
channels // (2 * ratio), channel_dim=-1,
|
||||||
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
|
||||||
max_positive=1.0,
|
max_positive=1.0,
|
||||||
min_abs=1.5,
|
min_abs=0.75,
|
||||||
max_abs=ScheduledFloat((0.0, 5.0), (8000.0, 10.0), default=1.0),
|
max_abs=ScheduledFloat((0.0, 2.5), (8000.0, 5.0), default=1.0),
|
||||||
)
|
)
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
self.activation = Identity() # for diagnostics.
|
self.activation = Identity() # for diagnostics.
|
||||||
self.out_proj = ScaledLinear(channels // 2, channels,
|
self.out_proj = ScaledLinear(channels // 2, channels,
|
||||||
@ -1508,7 +1508,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
x = x[..., :num_channels // 2]
|
x = x[..., :num_channels // 2]
|
||||||
|
|
||||||
s = self.balancer(s)
|
s = self.balancer(s)
|
||||||
s = self.sigmoid(s)
|
s = self.tanh(s)
|
||||||
|
|
||||||
s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size, num_channels // 2)
|
s = s.unsqueeze(-1).expand(-1, -1, -1, self.ratio).reshape(seq_len, batch_size, num_channels // 2)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user