Use half the dim for values, vs. keys and queries.

Daniel Povey 2022-10-17 22:15:06 +08:00
parent 3f495cd197
commit 2675944f01


@@ -843,7 +843,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5, dividing it between the query and key.
-        self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
+        self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
                                     initial_scale=self.head_dim**-0.25)
         # self.whiten_values is applied on the values in forward()
@@ -863,14 +863,14 @@ class RelPositionMultiheadAttention(nn.Module):
         self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
                                        initial_scale=0.05)
-        self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
+        self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
-            attention_dim, embed_dim, bias=True, initial_scale=0.05
+            attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )
-        self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
-        self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
         self.whiten_values2 = Whiten(num_groups=num_heads,
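
Taken together, these projection sizes encode the split that the commit title describes: the query and key each keep the full attention_dim, while the positional projection and the values each get attention_dim // 2, so the packed in_proj output shrinks from 7 * attention_dim // 2 to 3 * attention_dim and out_proj / out_proj2 now read from attention_dim // 2 channels. A minimal sketch of that bookkeeping, using plain nn.Linear as a stand-in for the repository's ScaledLinear and purely illustrative sizes:

import torch
import torch.nn as nn

# Illustrative sizes, not taken from the diff.
embed_dim, attention_dim = 512, 512

# q: attention_dim, k: attention_dim, p: attention_dim // 2, v: attention_dim // 2
# -> attention_dim + attention_dim + attention_dim // 2 + attention_dim // 2
#    == 3 * attention_dim  (was 7 * attention_dim // 2 when v used the full dim)
in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
out_proj = nn.Linear(attention_dim // 2, embed_dim, bias=True)

x = torch.randn(100, 4, embed_dim)   # (seq_len, batch, embed_dim)
print(in_proj(x).shape)              # torch.Size([100, 4, 1536])

Since the query/key widths are untouched, the attention weights are computed exactly as before; only the value path (and therefore out_proj and out_proj2) gets smaller.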
@@ -1005,10 +1005,8 @@ class RelPositionMultiheadAttention(nn.Module):
             # self-attention
-            q = x[...,:attention_dim]
-            k = x[...,attention_dim:2*attention_dim]
-            v = x[...,2*attention_dim:3*attention_dim]
-            p = x[...,3*attention_dim:]
+            q, k, pv = x.chunk(3, dim=-1)
+            p, v = pv.chunk(2, dim=-1)
         k = self.whiten_keys(k)  # does nothing in the forward pass.
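
A small sketch of what the new split does, assuming the in_proj output has width 3 * attention_dim as set up above (sizes again illustrative):

import torch

seq_len, bsz, attention_dim = 10, 2, 512

# Stand-in for the in_proj output of width 3 * attention_dim.
x = torch.randn(seq_len, bsz, 3 * attention_dim)

q, k, pv = x.chunk(3, dim=-1)   # three equal chunks of width attention_dim
p, v = pv.chunk(2, dim=-1)      # p and v each get attention_dim // 2

assert q.shape[-1] == attention_dim
assert k.shape[-1] == attention_dim
assert p.shape[-1] == attention_dim // 2
assert v.shape[-1] == attention_dim // 2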
@@ -1066,7 +1064,7 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.reshape(seq_len, bsz, num_heads, head_dim)
         p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
         k = k.reshape(seq_len, bsz, num_heads, head_dim)
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
         if key_padding_mask is not None:
@@ -1140,11 +1138,11 @@ class RelPositionMultiheadAttention(nn.Module):
             )
         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, attention_dim)
+            .view(seq_len, bsz, attention_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
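
For the shape flow downstream of the change, a hedged sketch (random tensors, plain nn.Linear instead of the out_proj weights applied via nn.functional.linear, illustrative sizes): the attention weights keep their (bsz * num_heads, seq_len, seq_len) shape, but the values carry head_dim // 2 channels, so the pre-projection output has width attention_dim // 2 and the output projection maps it back to embed_dim.

import torch
import torch.nn as nn

seq_len, bsz, num_heads = 10, 2, 8
embed_dim = attention_dim = 512
head_dim = attention_dim // num_heads

attn_output_weights = torch.softmax(
    torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1
)
v = torch.randn(bsz * num_heads, seq_len, head_dim // 2)

attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]

attn_output = (
    attn_output.transpose(0, 1)
    .contiguous()
    .view(seq_len, bsz, attention_dim // 2)
)

# Plain nn.Linear stand-in for the out_proj weights used above.
out_proj = nn.Linear(attention_dim // 2, embed_dim, bias=True)
print(out_proj(attn_output).shape)   # torch.Size([10, 2, 512])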
@@ -1173,9 +1171,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
-        # now v: (bsz * num_heads, seq_len, head_dim)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
+        # now v: (bsz * num_heads, seq_len, head_dim // 2)
         attn_output = torch.bmm(attn_weights, v)
         if random.random() < 0.001 or __name__ == "__main__":
@@ -1185,7 +1183,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, self.attention_dim)
+            .view(seq_len, bsz, self.attention_dim // 2)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
         return self.out_proj2(attn_output)
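
The forward2() path follows the same pattern: in_proj2 now produces only attention_dim // 2 channels, the attention weights passed into forward2() are used unchanged, and out_proj2 maps the half-width output back to embed_dim. A compact sketch under the same assumptions (plain nn.Linear standing in for ScaledLinear, illustrative sizes):

import torch
import torch.nn as nn

seq_len, bsz, num_heads = 10, 2, 8
embed_dim = attention_dim = 512
head_dim = attention_dim // num_heads

in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
out_proj2 = nn.Linear(attention_dim // 2, embed_dim, bias=True)

x = torch.randn(seq_len, bsz, embed_dim)
attn_weights = torch.softmax(torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)

v = in_proj2(x)                            # (seq_len, bsz, attention_dim // 2)
v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
attn_output = torch.bmm(attn_weights, v)   # (bsz * num_heads, seq_len, head_dim // 2)
attn_output = (
    attn_output.transpose(0, 1)
    .contiguous()
    .view(seq_len, bsz, attention_dim // 2)
)
print(out_proj2(attn_output).shape)        # torch.Size([10, 2, 512])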