Use half the dim per head in self_attn layers.

Daniel Povey 2022-09-24 15:40:44 +08:00
parent ce3f59d9c7
commit 71b3756ada


@@ -469,27 +469,27 @@ class RelPositionMultiheadAttention(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = embed_dim // (num_heads * 2)
         assert (
-            self.head_dim * num_heads == self.embed_dim
+            self.head_dim * num_heads == self.embed_dim // 2
         ), "embed_dim must be divisible by num_heads"
-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(3 * embed_dim,
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0,
                                               max_var_per_eig=0.2)
-        self.proj_balancer = ActivationBalancer(embed_dim,
+        self.proj_balancer = ActivationBalancer(embed_dim // 2,
                                                 channel_dim=-1, max_abs=10.0,
                                                 min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
         )
         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))
         # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
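The net effect of the constructor changes above: each head now gets embed_dim // (2 * num_heads) dimensions, so the concatenated attention output is embed_dim // 2 wide and out_proj maps it back up to embed_dim. A minimal shape sketch, assuming illustrative values embed_dim=512 and num_heads=8 (not from the diff) and substituting plain nn.Linear for the icefall-specific ActivationBalancer / ScaledLinear modules:

import torch
import torch.nn as nn

embed_dim, num_heads = 512, 8                                    # illustrative values only
head_dim = embed_dim // (num_heads * 2)                          # 32 instead of the usual 64
assert head_dim * num_heads == embed_dim // 2

in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)    # q, k, v each embed_dim // 2 wide
out_proj = nn.Linear(embed_dim // 2, embed_dim, bias=True)       # expands back to embed_dim
linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)    # positional projection halved too

x = torch.randn(10, 2, embed_dim)                                # (time, batch, embed_dim)
q, k, v = in_proj(x).chunk(3, dim=-1)                            # each (10, 2, 256)
print(q.shape, out_proj(v).shape)                                # (10, 2, 256) and (10, 2, 512)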
@@ -663,9 +663,9 @@ class RelPositionMultiheadAttention(nn.Module):
         tgt_len, bsz, _ = x.size()
-        head_dim = embed_dim // num_heads
+        head_dim = embed_dim // (num_heads * 2)
         assert (
-            head_dim * num_heads == embed_dim
+            head_dim * num_heads == embed_dim // 2
         ), "embed_dim must be divisible by num_heads"
         scaling = float(head_dim) ** -0.5
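A side effect of recomputing head_dim in this hunk is that the softmax scaling changes too, since scaling = float(head_dim) ** -0.5. A quick check with the same illustrative sizes (embed_dim=512, num_heads=8, values not from the diff):

embed_dim, num_heads = 512, 8                                # illustrative values only
old_scaling = float(embed_dim // num_heads) ** -0.5          # 64 ** -0.5 = 0.125
new_scaling = float(embed_dim // (num_heads * 2)) ** -0.5    # 32 ** -0.5 ~= 0.177
print(old_scaling, new_scaling)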
@@ -815,7 +815,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(tgt_len, bsz, embed_dim)
+            .view(tgt_len, bsz, embed_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
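On the output path changed in the last hunk, the per-head outputs concatenate to width embed_dim // 2 before out_proj restores the full width, so out_proj_weight has shape (embed_dim, embed_dim // 2). A rough shape-only sketch of that bookkeeping, with illustrative sizes and random tensors standing in for the real attention output and projection weights:

import torch

tgt_len, bsz, embed_dim, num_heads = 10, 2, 512, 8            # illustrative values only
head_dim = embed_dim // (num_heads * 2)                       # 32

attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim) # stand-in for the attention result
attn_output = (
    attn_output.transpose(0, 1)                               # (tgt_len, bsz * num_heads, head_dim)
    .contiguous()
    .view(tgt_len, bsz, embed_dim // 2)                       # heads concatenated: width 256
)
out_proj_weight = torch.randn(embed_dim, embed_dim // 2)      # (512, 256)
out_proj_bias = torch.zeros(embed_dim)
attn_output = torch.nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
print(attn_output.shape)                                      # (10, 2, 512)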