Changes for debugging/stats.
This commit is contained in:
parent
48d699c94b
commit
faed28ba6a
@@ -1434,6 +1434,7 @@ class NonlinAttentionModule(nn.Module):
|
|||||||
min_abs=0.2, max_abs=10.0,
|
min_abs=0.2, max_abs=10.0,
|
||||||
min_prob=0.05,
|
min_prob=0.05,
|
||||||
)
|
)
|
||||||
|
self.pre_sigmoid = Identity() # for diagnostics.
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
self.activation = Identity() # for diagnostics.
|
self.activation = Identity() # for diagnostics.
|
||||||
@@ -1475,6 +1476,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
|
|||||||
|
|
||||||
v = self.whiten1(v)
|
v = self.whiten1(v)
|
||||||
# GLU mechanism
|
# GLU mechanism
|
||||||
|
s = self.pre_sigmoid(s)
|
||||||
x = self.sigmoid(s) * v
|
x = self.sigmoid(s) * v
|
||||||
x = self.balancer(x)
|
x = self.balancer(x)
|
||||||
|
|
||||||
@@ -1537,6 +1539,9 @@ class ConvolutionModule(nn.Module):
|
|||||||
max_abs=10.0, min_positive=0.05, max_positive=1.0
|
max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.pre_sigmoid = Identity() # before sigmoid; for diagnostics.
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
self.depthwise_conv = nn.Conv1d(
|
self.depthwise_conv = nn.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
@@ -1585,7 +1590,13 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||||
x = self.deriv_balancer1(x)
|
x = self.deriv_balancer1(x)
|
||||||
x = nn.functional.glu(x, dim=-1) # (time, batch, channels)
|
|
||||||
|
x, s = x.chunk(2, dim=-1)
|
||||||
|
s = self.pre_sigmoid(s)
|
||||||
|
s = self.sigmoid(s)
|
||||||
|
x = x * s
|
||||||
|
|
||||||
|
# (time, batch, channels)
|
||||||
|
|
||||||
# exchange the temporal dimension and the feature dimension
|
# exchange the temporal dimension and the feature dimension
|
||||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user