Mirror of https://github.com/k2-fsa/icefall.git
Use half the dim for values, vs. keys and queries.
parent 3f495cd197
commit 2675944f01
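The projection sizes change because the values v (and the positional projection p) now use half of attention_dim, while the queries and keys keep the full width. A minimal sketch of the dimension budget, assuming an illustrative attention_dim that is not taken from this commit:

    attention_dim = 512  # illustrative only

    # before: q, k, v at full width, p at half width
    old_width = 3 * attention_dim + attention_dim // 2        # == 7 * attention_dim // 2

    # after: q and k at full width, p and v both at half width
    new_width = 2 * attention_dim + 2 * (attention_dim // 2)  # == 3 * attention_dim

    assert old_width == 7 * attention_dim // 2 == 1792
    assert new_width == 3 * attention_dim == 1536

The diff below applies this budget to the combined input projection, the balancer, both output projections, and every place the value tensor's shape appears.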
@@ -843,7 +843,7 @@ class RelPositionMultiheadAttention(nn.Module):
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5, dividing it between the query and key.
-        self.in_proj = ScaledLinear(embed_dim, 7 * attention_dim // 2, bias=True,
+        self.in_proj = ScaledLinear(embed_dim, 3 * attention_dim, bias=True,
                                     initial_scale=self.head_dim**-0.25)

         # self.whiten_values is applied on the values in forward()
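The unchanged context comment above explains the initial_scale choice: rather than multiplying the attention logits by head_dim ** -0.5, both q and k are initialized with scale head_dim ** -0.25, which compounds to the same factor. A quick numeric check, with an illustrative head_dim:

    import torch

    head_dim = 64  # illustrative only
    s = head_dim ** -0.25
    q, k = torch.randn(head_dim), torch.randn(head_dim)

    # scaling q and k by head_dim**-0.25 each is equivalent to the usual
    # head_dim**-0.5 scaling of the q.k attention logits
    torch.testing.assert_close((q * s) @ (k * s), (q @ k) * head_dim ** -0.5)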
@@ -863,14 +863,14 @@ class RelPositionMultiheadAttention(nn.Module):
         self.linear_pos = ScaledLinear(embed_dim, attention_dim // 2, bias=False,
                                        initial_scale=0.05)

-        self.in_balancer = ActivationBalancer(7 * attention_dim // 2,
+        self.in_balancer = ActivationBalancer(3 * attention_dim,
                                               channel_dim=-1, max_abs=5.0)
         self.out_proj = ScaledLinear(
-            attention_dim, embed_dim, bias=True, initial_scale=0.05
+            attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
         )

-        self.in_proj2 = nn.Linear(embed_dim, attention_dim, bias=False)
-        self.out_proj2 = ScaledLinear(attention_dim, embed_dim, bias=True,
+        self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
+        self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True,
                                       initial_scale=0.05)
         # self.whiten_values2 is applied on the values in forward2()
         self.whiten_values2 = Whiten(num_groups=num_heads,
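Everything downstream must agree on the new widths: the balancer spans the full 3 * attention_dim projection, while out_proj and out_proj2 now consume attention_dim // 2. A shape-only sketch using plain nn.Linear as a stand-in for icefall's ScaledLinear (an assumption for illustration; the real class adds a learned scale):

    import torch
    import torch.nn as nn

    embed_dim = attention_dim = 512  # illustrative only
    in_proj = nn.Linear(embed_dim, 3 * attention_dim, bias=True)
    out_proj = nn.Linear(attention_dim // 2, embed_dim, bias=True)
    out_proj2 = nn.Linear(attention_dim // 2, embed_dim, bias=True)

    x = torch.randn(100, 8, embed_dim)  # (seq_len, bsz, embed_dim)
    assert in_proj(x).shape == (100, 8, 3 * attention_dim)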
@@ -1005,10 +1005,8 @@ class RelPositionMultiheadAttention(nn.Module):


         # self-attention
-        q = x[...,:attention_dim]
-        k = x[...,attention_dim:2*attention_dim]
-        v = x[...,2*attention_dim:3*attention_dim]
-        p = x[...,3*attention_dim:]
+        q, k, pv = x.chunk(3, dim=-1)
+        p, v = pv.chunk(2, dim=-1)


         k = self.whiten_keys(k)  # does nothing in the forward pass.
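With the new layout the combined projection is ordered [q | k | p | v], with p and v each at half width (note that v and p trade places relative to the old slicing), so two chunk() calls recover all four pieces. A small equivalence check under illustrative sizes:

    import torch

    attention_dim = 8  # illustrative; must be even
    x = torch.randn(3, 2, 3 * attention_dim)

    q, k, pv = x.chunk(3, dim=-1)
    p, v = pv.chunk(2, dim=-1)

    assert torch.equal(q, x[..., :attention_dim])
    assert torch.equal(k, x[..., attention_dim:2 * attention_dim])
    assert torch.equal(p, x[..., 2 * attention_dim:5 * attention_dim // 2])
    assert torch.equal(v, x[..., 5 * attention_dim // 2:])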
@@ -1066,7 +1064,7 @@ class RelPositionMultiheadAttention(nn.Module):
         q = q.reshape(seq_len, bsz, num_heads, head_dim)
         p = p.reshape(seq_len, bsz, num_heads, head_dim // 2)
         k = k.reshape(seq_len, bsz, num_heads, head_dim)
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)


         if key_padding_mask is not None:
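Only the v reshape changes here: q and k keep head_dim per head, while p and v carry head_dim // 2. A shape check with illustrative sizes:

    import torch

    seq_len, bsz, num_heads, head_dim = 5, 2, 4, 16  # illustrative only
    v = torch.randn(seq_len, bsz, num_heads * (head_dim // 2))

    v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
    assert v.shape == (bsz * num_heads, seq_len, head_dim // 2)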
@@ -1140,11 +1138,11 @@ class RelPositionMultiheadAttention(nn.Module):
         )

         attn_output = torch.bmm(attn_output_weights, v)
-        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim]
+        assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, attention_dim)
+            .view(seq_len, bsz, attention_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
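The batched matmul inherits the halved width: weights of shape (bsz * num_heads, seq_len, seq_len) against values of shape (bsz * num_heads, seq_len, head_dim // 2) yield head_dim // 2 per head, which the view collapses to attention_dim // 2 before the final linear maps back to embed_dim. A hedged shape sketch with illustrative sizes:

    import torch

    seq_len, bsz, num_heads, head_dim = 5, 2, 4, 16  # illustrative only
    attention_dim = num_heads * head_dim

    attn_output_weights = torch.softmax(
        torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)
    v = torch.randn(bsz * num_heads, seq_len, head_dim // 2)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2]

    attn_output = (attn_output.transpose(0, 1)
                   .contiguous()
                   .view(seq_len, bsz, attention_dim // 2))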
@@ -1173,9 +1171,9 @@ class RelPositionMultiheadAttention(nn.Module):
         # v: (tgt_len, bsz, embed_dim // 2)
         v = self.in_proj2(x)
         v = self.whiten_values2(v)  # does nothing in the forward pass.
-        v = v.reshape(seq_len, bsz * num_heads, head_dim).transpose(0, 1)
+        v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)

-        # now v: (bsz * num_heads, seq_len, head_dim)
+        # now v: (bsz * num_heads, seq_len, head_dim // 2)
         attn_output = torch.bmm(attn_weights, v)

         if random.random() < 0.001 or __name__ == "__main__":
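forward2() reuses the attention weights already computed in forward() against a second, independently projected value tensor, which is now half width as well. A minimal sketch of that reuse (names mirror the diff; sizes illustrative):

    import torch
    import torch.nn as nn

    seq_len, bsz, num_heads, head_dim = 5, 2, 4, 16  # illustrative only
    embed_dim = attention_dim = num_heads * head_dim

    in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False)
    x = torch.randn(seq_len, bsz, embed_dim)
    attn_weights = torch.softmax(
        torch.randn(bsz * num_heads, seq_len, seq_len), dim=-1)

    v = in_proj2(x)  # (seq_len, bsz, attention_dim // 2)
    v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1)
    attn_output = torch.bmm(attn_weights, v)
    assert attn_output.shape == (bsz * num_heads, seq_len, head_dim // 2)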
@@ -1185,7 +1183,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(seq_len, bsz, self.attention_dim)
+            .view(seq_len, bsz, self.attention_dim // 2)
         )
         # returned value is of shape (seq_len, bsz, embed_dim), like x.
         return self.out_proj2(attn_output)
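Ignoring biases, the four projections touched here shrink by 2 * embed_dim * attention_dim weights in total: a quarter of that from in_proj and the rest from the three halved value paths. A rough count under the same illustrative sizes, not measured on the real model:

    embed_dim = attention_dim = 512  # illustrative only

    old = (embed_dim * (7 * attention_dim // 2)  # in_proj
           + attention_dim * embed_dim           # out_proj
           + embed_dim * attention_dim           # in_proj2
           + attention_dim * embed_dim)          # out_proj2

    new = (embed_dim * (3 * attention_dim)       # in_proj
           + (attention_dim // 2) * embed_dim    # out_proj
           + embed_dim * (attention_dim // 2)    # in_proj2
           + (attention_dim // 2) * embed_dim)   # out_proj2

    assert old - new == 2 * embed_dim * attention_dim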