Use half the dim for values, vs. keys and queries.

Daniel Povey 2022-10-17 22:15:06 +08:00
parent 3f495cd197
commit 2675944f01


@@ -843,7 +843,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5, dividing it between the query and key.
-        self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
+        self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
                                     initial_scale=self.head_dim**-0.25)
         # self.whiten_values is applied on the values in forward()
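The first change: the fused input projection now produces 3 * attention_dim features instead of 7 * attention_dim // 2, because the value slice shrinks from attention_dim to attention_dim // 2 while queries and keys keep the full attention_dim (the positional projection p was already half-dim). For illustration only, not part of the patch — a rough sketch of the new layout, with nn.Linear standing in for this repo's ScaledLinear and made-up sizes:

    import torch.nn as nn

    embed_dim, attention_dim = 512, 512   # example sizes, not taken from the patch
    in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
    # Layout of the output's last dimension in the new code:
    #   [0, attention_dim)                          -> q  (queries, full dim)
    #   [attention_dim, 2 * attention_dim)          -> k  (keys, full dim)
    #   [2 * attention_dim, 5 * attention_dim // 2) -> p  (positional projection, half dim)
    #   [5 * attention_dim // 2, 3 * attention_dim) -> v  (values, now half dim)
    # Old code: q | k | full-dim v | half-dim p, totalling 7 * attention_dim // 2.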
@@ -863,14 +863,14 @@ class RelPositionMultiheadAttention(nn.Module):
         self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
                                        initial_scale=0.05)
-        self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
+        self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
-            attention_dim, embed_dim, bias=True, initial_scale=0.05
+            attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
-        self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
-        self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
         self.whiten_values2 = Whiten(num_groups=num_heads,
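Since the values, and hence the per-head attention output, now live in attention_dim // 2 channels, the balancer is resized to the new in_proj width and every projection on the value side is halved: out_proj, in_proj2 and out_proj2. Not part of the patch, but as a rough sense of the saving (weights only, biases and the positional/whitening modules ignored, assuming example sizes embed_dim = attention_dim = 512):

    E = A = 512                              # example sizes only
    old = E * (7 * A // 2) + A * E           # in_proj + out_proj weight counts before
    new = E * (3 * A) + (A // 2) * E         # the same two layers after this change
    print(old, new)                          # 1179648 vs 917504, roughly 22% fewer weights here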
@@ -1005,10 +1005,8 @@ class RelPositionMultiheadAttention(nn.Module):
         # self-attention
-        q = x[...,:attention_dim]
-        k = x[...,attention_dim:2*attention_dim]
-        v = x[...,2*attention_dim:3*attention_dim]
-        p = x[...,3*attention_dim:]
+        q, k, pv = x.chunk(3, dim=-1)
+        p, v = pv.chunk(2, dim=-1)
         k = self.whiten_keys(k)  # does nothing in the forward pass.
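In the forward pass the old explicit slicing assumed the q | k | v | p layout; the new chunking reads the q | k | p | v layout, giving q and k attention_dim channels each and p, v attention_dim // 2 each. A self-contained sanity check with made-up sizes (illustration, not from the patch):

    import torch

    seq_len, bsz, attention_dim = 10, 2, 512          # example sizes
    x = torch.randn(seq_len, bsz, 3 * attention_dim)  # output of the new in_proj
    q, k, pv = x.chunk(3, dim=-1)
    p, v = pv.chunk(2, dim=-1)
    print(q.shape, k.shape, p.shape, v.shape)
    # torch.Size([10, 2, 512]) torch.Size([10, 2, 512]) torch.Size([10, 2, 256]) torch.Size([10, 2, 256])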
@@ -1066,7 +1064,7 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.reshape(seq_len, bsz, num_heads, head_dim)
         p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
         k = k.reshape(seq_len, bsz, num_heads, head_dim)
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
         if key_padding_mask is not None:
@ -1140,11 +1138,11 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(seq_len, bsz, attention_dim)
.view(seq_len, bsz, attention_dim // 2)
)
attn_output = nn.functional.linear(
attn_output, out_proj_weight, out_proj_bias
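Downstream shapes follow from the halved values: the attention weights are unchanged at (bsz * num_heads, seq_len, seq_len), but the bmm output carries head_dim // 2 channels per head and is folded back to (seq_len, bsz, attention_dim // 2) before out_proj maps it to embed_dim. A quick shape walk-through with example sizes (illustration, not from the patch):

    import torch

    bsz, num_heads, seq_len, head_dim = 2, 8, 10, 64    # example sizes; attention_dim = 512
    attn_output_weights = torch.softmax(torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)
    v = torch.randn(bsz * num_heads, seq_len, head_dim // 2)
    attn_output = torch.bmm(attn_output_weights, v)      # (16, 10, 32)
    attn_output = (
        attn_output.transpose(0, 1)
        .contiguous()
        .view(seq_len, bsz, num_heads * (head_dim // 2))
    )
    print(attn_output.shape)                             # (10, 2, 256) == (seq_len, bsz, attention_dim // 2)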
@@ -1173,9 +1171,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
-        # now v: (bsz * num_heads, seq_len, head_dim)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
+        # now v: (bsz * num_heads, seq_len, head_dim // 2)
         attn_output = torch.bmm(attn_weights, v)
         if random.random() < 0.001 or __name__ == "__main__":
@@ -1185,7 +1183,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, self.attention_dim)
+            .view(seq_len, bsz, self.attention_dim // 2)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
         return self.out_proj2(attn_output)
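The second attention pass (forward2) reuses the attention weights computed in forward() but now carries the half-width values end to end: in_proj2 emits attention_dim // 2 channels, each head works with head_dim // 2, and out_proj2 maps attention_dim // 2 back to embed_dim. A standalone sketch of that value path, with plain nn.Linear in place of ScaledLinear/Whiten and made-up sizes (illustration only, not the repo's code):

    import torch
    import torch.nn as nn

    seq_len, bsz, embed_dim = 10, 2, 512      # example sizes
    num_heads, head_dim = 8, 64               # attention_dim = num_heads * head_dim = 512
    attention_dim = num_heads * head_dim

    in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
    out_proj2 = nn.Linear(attention_dim // 2, embed_dim, bias=True)

    x = torch.randn(seq_len, bsz, embed_dim)
    # attn_weights would come from forward(); random here just to exercise the shapes
    attn_weights = torch.softmax(torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)

    v = in_proj2(x)                                         # (seq_len, bsz, attention_dim // 2)
    v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
    attn_output = torch.bmm(attn_weights, v)                # (bsz * num_heads, seq_len, head_dim // 2)
    attn_output = (
        attn_output.transpose(0, 1)
        .contiguous()
        .view(seq_len, bsz, attention_dim // 2)
    )
    out = out_proj2(attn_output)                            # (seq_len, bsz, embed_dim), like x
    print(out.shape)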