mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp503' into scaled_adam_exp505
This commit is contained in:
commit
e19118a966
@ -1442,6 +1442,7 @@ class NonlinAttentionModule(nn.Module):
|
||||
min_abs=0.2, max_abs=10.0,
|
||||
min_prob=0.05,
|
||||
)
|
||||
self.pre_sigmoid = Identity() # for diagnostics.
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.activation = Identity() # for diagnostics.
|
||||
@ -1484,6 +1485,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
|
||||
v = self.whiten1(v)
|
||||
# GLU mechanism
|
||||
s = self.pre_sigmoid(s)
|
||||
x = self.sigmoid(s) * v
|
||||
x = self.balancer(x)
|
||||
|
||||
@ -1546,6 +1548,9 @@ class ConvolutionModule(nn.Module):
|
||||
max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@ -1594,7 +1599,13 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||
x = self.deriv_balancer1(x)
|
||||
x = nn.functional.glu(x, dim=-1) # (time, batch, channels)
|
||||
|
||||
x, s = x.chunk(2, dim=-1)
|
||||
s = self.pre_sigmoid(s)
|
||||
s = self.sigmoid(s)
|
||||
x = x * s
|
||||
|
||||
# (time, batch, channels)
|
||||
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user