Merge branch 'scaled_adam_exp912' into scaled_adam_exp994

Daniel Povey 2023-02-08 21:15:21 +08:00
commit b2fb504aee

@@ -1592,7 +1592,7 @@ class NonlinAttention(nn.Module):
         self.hidden_channels = hidden_channels
-        self.in_proj = nn.Linear(channels, hidden_channels * 4, bias=True)
+        self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True)
         # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0,
         # because we noticed that well-trained instances of this module have abs-value before the sigmoid
@@ -1645,9 +1645,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         (seq_len, batch_size, _) = x.shape
         hidden_channels = self.hidden_channels
-        x, y = x.chunk(2, dim=-1)
-        s, x = x.chunk(2, dim=-1)
+        s, x, y = x.chunk(3, dim=-1)
         # s will go through tanh.
@@ -1669,7 +1667,7 @@ attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len)
         # now x: (num_heads, batch_size, seq_len, head_dim)
         x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
-        y = torch.nn.functional.glu(y, dim=-1)
         y = self.identity2(y)
         x = x * y
         x = self.identity3(x)
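
Net effect of the merge: NonlinAttention's input projection shrinks from 4 * hidden_channels to 3 * hidden_channels, splitting directly into (s, x, y) and dropping the GLU that previously halved a doubled-width y. Below is a minimal sketch of the shape bookkeeping before and after, outside the full module; the channel sizes and random input are illustrative assumptions, not values from the repo.

import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not taken from the repo).
channels, hidden_channels = 8, 4
seq_len, batch_size = 5, 2
x_in = torch.randn(seq_len, batch_size, channels)

# Before the merge: project to 4 * hidden_channels, split off a
# 2 * hidden_channels y, then let GLU halve it via a sigmoid gate.
in_proj_old = nn.Linear(channels, hidden_channels * 4, bias=True)
x, y = in_proj_old(x_in).chunk(2, dim=-1)   # each 2 * hidden_channels
s, x = x.chunk(2, dim=-1)                   # each hidden_channels
y = torch.nn.functional.glu(y, dim=-1)      # a * sigmoid(b) -> hidden_channels
assert s.shape[-1] == x.shape[-1] == y.shape[-1] == hidden_channels

# After the merge: project to 3 * hidden_channels and split directly;
# y is now a plain linear output with no GLU gating.
in_proj_new = nn.Linear(channels, hidden_channels * 3, bias=True)
s, x, y = in_proj_new(x_in).chunk(3, dim=-1)
assert s.shape[-1] == x.shape[-1] == y.shape[-1] == hidden_channels

The per-branch widths (s, x, y all hidden_channels wide) are unchanged; what the diff visibly changes is a 25% reduction in in_proj parameters and one fewer nonlinearity on the y branch.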