Add a multiplication to NonlinAttentionModule

Daniel Povey 2023-01-14 20:41:30 +08:00
parent ec8804283c
commit eeadc3b0cc


@@ -1602,7 +1602,7 @@ class NonlinAttention(nn.Module):
         self.hidden_channels = hidden_channels
 
-        self.in_proj = nn.Linear(channels, hidden_channels * 2, bias=True)
+        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
 
         # balancer that goes before the sigmoid.  Have quite a large min_abs value, at 2.0,
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
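With in_proj widened from hidden_channels * 2 to hidden_channels * 3, its output now splits into three equal streams instead of two. A minimal shape check of that change (the sizes below are hypothetical, not the ones used in training):

    import torch
    import torch.nn as nn

    # in_proj now emits hidden_channels * 3, so forward() can chunk its
    # output into three equal streams (s, x, y) along the last dimension.
    seq_len, batch_size, channels, hidden_channels = 10, 2, 256, 192
    in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)

    t = torch.randn(seq_len, batch_size, channels)
    s, x, y = in_proj(t).chunk(3, dim=-1)
    assert s.shape == x.shape == y.shape == (seq_len, batch_size, hidden_channels)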
@@ -1617,7 +1617,10 @@ class NonlinAttention(nn.Module):
         )
         self.tanh = nn.Tanh()
 
-        self.activation = Identity()  # for diagnostics.
+        self.identity1 = Identity()  # for diagnostics.
+        self.identity2 = Identity()  # for diagnostics.
+        self.identity3 = Identity()  # for diagnostics.
+
         self.out_proj = ScaledLinear(hidden_channels, channels,
                                      bias=True,
                                      initial_scale=0.05)
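The three Identity modules replace the single self.activation. Identity in icefall's scaling.py is a pure pass-through; registering one per stream gives each intermediate tensor a stably named node (identity1/2/3) that forward-hook-based diagnostics can attach to. A sketch of such a module, written as an assumption about the helper rather than a copy of it:

    import torch
    import torch.nn as nn

    class Identity(nn.Module):
        # Functionally a no-op: returns its input unchanged. Its only role
        # is to expose a named point in the module graph for diagnostics.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x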
@@ -1652,16 +1655,17 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
 
-        s = x[..., hidden_channels:]
-        x = x[..., :hidden_channels]
+        s, x, y = x.chunk(3, dim=-1)
+
+        # s will go through tanh.
         s = self.balancer(s)
         s = self.tanh(s)
 
         s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels)
 
         x = self.whiten1(x)
-        x = self.activation(x)  # diagnostics only, it's the identity.
         x = x * s
+        x = self.identity1(x)  # diagnostics only, it's the identity.
 
         (seq_len, batch_size, embed_dim) = x.shape
         num_heads = attn_weights.shape[0]
@@ -1673,6 +1677,11 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
 
+        y = self.identity2(y)
+        x = x * y
+        x = self.identity3(x)
+
         x = self.out_proj(x)
         x = self.whiten2(x)
         return x
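Taken together, the commit gives NonlinAttention a third stream y that multiplies the attention-mixed output just before out_proj, in addition to the existing tanh gate s that multiplies before the mixing. Below is a self-contained sketch of the post-commit data flow; the Balancer, Whiten, Identity, and ScaledLinear pieces are replaced by plain PyTorch, and the attention-mixing step between the last two hunks is reconstructed from the surrounding context, so its details are assumptions:

    import torch
    import torch.nn as nn

    class NonlinAttentionSketch(nn.Module):
        # Simplified data flow of NonlinAttention after this commit.

        def __init__(self, channels: int, hidden_channels: int):
            super().__init__()
            self.hidden_channels = hidden_channels
            self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
            self.tanh = nn.Tanh()
            self.out_proj = nn.Linear(hidden_channels, channels, bias=True)

        def forward(self, x, attn_weights):
            # x: (seq_len, batch_size, channels)
            # attn_weights: (num_heads, batch_size, seq_len, seq_len)
            x = self.in_proj(x)
            s, x, y = x.chunk(3, dim=-1)   # was a 2-way split before this commit
            x = x * self.tanh(s)           # tanh gate on the first stream

            (seq_len, batch_size, embed_dim) = x.shape
            num_heads = attn_weights.shape[0]
            head_dim = embed_dim // num_heads

            # Mix x across time with the externally computed attention weights
            # (reconstructed step; the real code sits between the two hunks).
            x = x.reshape(seq_len, batch_size, num_heads, head_dim).permute(2, 1, 0, 3)
            x = torch.matmul(attn_weights, x)
            # now x: (num_heads, batch_size, seq_len, head_dim)
            x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)

            x = x * y                      # the multiplication this commit adds
            return self.out_proj(x)

    # Hypothetical sizes: 192 hidden channels split over 4 heads of 48 dims.
    m = NonlinAttentionSketch(channels=256, hidden_channels=192)
    attn = torch.softmax(torch.randn(4, 2, 10, 10), dim=-1)
    out = m(torch.randn(10, 2, 256), attn)
    assert out.shape == (10, 2, 256)

Unlike s, which gates x before the heads mix information across time, y scales the mixed output elementwise, so the added multiplication lets each position modulate what attention returned to it.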