Make scale in NonlinAttention have glu nonlinearity.

Daniel Povey 2023-01-15 00:21:01 +08:00
parent eeadc3b0cc
commit 048b6b6259


@@ -1602,7 +1602,7 @@ class NonlinAttention(nn.Module):
         self.hidden_channels = hidden_channels
-        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
+        self.in_proj = nn.Linear(channels, hidden_channels * 4, bias=True)
         # balancer that goes before the sigmoid.  Have quite a large min_abs value, at 2.0,
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
@@ -1655,7 +1655,9 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
-        s, x, y = x.chunk(3, dim=-1)
+        x, y = x.chunk(2, dim=-1)
+        s, x = x.chunk(2, dim=-1)
         # s will go through tanh.
@@ -1677,7 +1679,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
+        y = torch.nn.functional.glu(y, dim=-1)
         y = self.identity2(y)
         x = x * y
         x = self.identity3(x)
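
For orientation, below is a minimal, self-contained sketch of the tensor flow this change produces; it is not the actual icefall module. The sizes (seq_len, batch_size, channels, hidden) are illustrative, torch.tanh stands in for the balancer + tanh path applied to s, and the whitening/identity layers and the attention-weighted mixing in the middle of forward() are omitted. It shows the 4x-wide projection, the two-step chunk into s, x, and y, and the new GLU applied to y before it scales the output.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sizes only; the real module's channels/hidden_channels differ.
    seq_len, batch_size, channels, hidden = 5, 2, 16, 4

    in_proj = nn.Linear(channels, hidden * 4, bias=True)  # was hidden * 3 before this commit

    x = torch.randn(seq_len, batch_size, channels)
    x = in_proj(x)                 # (seq_len, batch_size, 4 * hidden)

    x, y = x.chunk(2, dim=-1)      # x: (..., 2 * hidden), y: (..., 2 * hidden)
    s, x = x.chunk(2, dim=-1)      # s: (..., hidden),     x: (..., hidden)

    s = torch.tanh(s)              # stand-in for the balancer + tanh applied to s
    x = x * s                      # gate the attention input by tanh(s)

    # ... the real forward() mixes x with attn_weights at this point ...

    y = F.glu(y, dim=-1)           # new in this commit: glu halves y to (..., hidden)
    x = x * y                      # y acts as the "scale" on the output
    print(x.shape)                 # torch.Size([5, 2, 4])

Because y now has 2 * hidden_channels rather than hidden_channels, the GLU can both gate and rescale the output, which is why the in_proj width grows from 3x to 4x hidden_channels in this commit.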