mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Changes for debugging/stats.
This commit is contained in:
parent
48d699c94b
commit
faed28ba6a
@ -1434,6 +1434,7 @@ class NonlinAttentionModule(nn.Module):
|
||||
min_abs=0.2, max_abs=10.0,
|
||||
min_prob=0.05,
|
||||
)
|
||||
self.pre_sigmoid = Identity() # for diagnostics.
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.activation = Identity() # for diagnostics.
|
||||
@ -1475,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
||||
|
||||
v = self.whiten1(v)
|
||||
# GLU mechanism
|
||||
s = self.pre_sigmoid(s)
|
||||
x = self.sigmoid(s) * v
|
||||
x = self.balancer(x)
|
||||
|
||||
@ -1537,6 +1539,9 @@ class ConvolutionModule(nn.Module):
|
||||
max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@ -1585,7 +1590,13 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||
x = self.deriv_balancer1(x)
|
||||
x = nn.functional.glu(x, dim=-1) # (time, batch, channels)
|
||||
|
||||
x, s = x.chunk(2, dim=-1)
|
||||
s = self.pre_sigmoid(s)
|
||||
s = self.sigmoid(s)
|
||||
x = x * s
|
||||
|
||||
# (time, batch, channels)
|
||||
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user