Changes for debugging/stats.

This commit is contained in:
Daniel Povey 2022-11-26 18:59:15 +08:00
parent 48d699c94b
commit faed28ba6a

View File

@ -1434,6 +1434,7 @@ class NonlinAttentionModule(nn.Module):
min_abs=0.2, max_abs=10.0,
min_prob=0.05,
)
self.pre_sigmoid = Identity() # for diagnostics.
self.sigmoid = nn.Sigmoid()
self.activation = Identity() # for diagnostics.
@ -1475,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
v = self.whiten1(v)
# GLU mechanism
s = self.pre_sigmoid(s)
x = self.sigmoid(s) * v
x = self.balancer(x)
@ -1537,6 +1539,9 @@ class ConvolutionModule(nn.Module):
max_abs=10.0, min_positive=0.05, max_positive=1.0
)
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
self.sigmoid = nn.Sigmoid()
self.depthwise_conv = nn.Conv1d(
channels,
channels,
@ -1585,7 +1590,13 @@ class ConvolutionModule(nn.Module):
x = self.in_proj(x) # (time, batch, 2*channels)
x = self.deriv_balancer1(x)
x = nn.functional.glu(x, dim=-1) # (time, batch, channels)
x, s = x.chunk(2, dim=-1)
s = self.pre_sigmoid(s)
s = self.sigmoid(s)
x = x * s
# (time, batch, channels)
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).