Use half the dim per head, in self_attn layers.
parent ce3f59d9c7
commit 71b3756ada
@@ -469,27 +469,27 @@ class RelPositionMultiheadAttention(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
-        self.head_dim = embed_dim // num_heads
+        self.head_dim = embed_dim // (num_heads * 2)
         assert (
-            self.head_dim * num_heads == self.embed_dim
+            self.head_dim * num_heads == self.embed_dim // 2
         ), "embed_dim must be divisible by num_heads"

-        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
-        self.in_balancer = ActivationBalancer(3 * embed_dim,
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
+        self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
                                               channel_dim=-1, max_abs=5.0,
                                               max_var_per_eig=0.2)
-        self.proj_balancer = ActivationBalancer(embed_dim,
+        self.proj_balancer = ActivationBalancer(embed_dim // 2,
                                                 channel_dim=-1, max_abs=10.0,
                                                 min_positive=0.0, max_positive=1.0)
         self.out_proj = ScaledLinear(
-            embed_dim, embed_dim, bias=True, initial_scale=0.5
+            embed_dim // 2, embed_dim, bias=True, initial_scale=0.5
         )

         self.attn_scores_proj_in = nn.Parameter(torch.eye(num_heads))
         self.attn_scores_proj_out = nn.Parameter(torch.zeros(num_heads, num_heads))

         # linear transformation for positional encoding.
-        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim // 2, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
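What the first hunk amounts to: num_heads stays the same, but each head's query/key/value dimension halves, so the attention path runs in embed_dim // 2 and only out_proj restores the full model dimension. A minimal, hypothetical sketch of that shape arithmetic with plain nn.Linear stand-ins (the module itself uses icefall's ScaledLinear and ActivationBalancer):

import torch
import torch.nn as nn

embed_dim, num_heads = 512, 8
head_dim = embed_dim // (num_heads * 2)        # 32 per head instead of 64
assert head_dim * num_heads == embed_dim // 2  # attention width is now embed_dim // 2

# q, k, v each have width embed_dim // 2, so the fused projection outputs 3 * embed_dim // 2.
in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
# out_proj maps the half-width attention output back to the model dimension.
out_proj = nn.Linear(embed_dim // 2, embed_dim, bias=True)

x = torch.randn(10, 4, embed_dim)              # (tgt_len, bsz, embed_dim)
q, k, v = in_proj(x).chunk(3, dim=-1)          # each (10, 4, embed_dim // 2)
print(out_proj(q).shape)                       # back to (10, 4, embed_dim)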
@@ -663,9 +663,9 @@ class RelPositionMultiheadAttention(nn.Module):

         tgt_len, bsz, _ = x.size()

-        head_dim = embed_dim // num_heads
+        head_dim = embed_dim // (num_heads * 2)
         assert (
-            head_dim * num_heads == embed_dim
+            head_dim * num_heads == embed_dim // 2
         ), "embed_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5
@@ -815,7 +815,7 @@ class RelPositionMultiheadAttention(nn.Module):
         attn_output = (
             attn_output.transpose(0, 1)
             .contiguous()
-            .view(tgt_len, bsz, embed_dim)
+            .view(tgt_len, bsz, embed_dim // 2)
         )
         attn_output = nn.functional.linear(
             attn_output, out_proj_weight, out_proj_bias
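The forward-pass hunks line up with the new sizes: head_dim (and hence the softmax scaling) is computed from embed_dim // (num_heads * 2), and the per-head outputs are concatenated into a (tgt_len, bsz, embed_dim // 2) tensor before the output projection brings it back to embed_dim. A hedged sketch of that final reshape-and-project step; the out_proj_weight / out_proj_bias shapes are assumptions consistent with the new out_proj above, not code from the repo:

import torch
import torch.nn as nn

tgt_len, bsz, embed_dim, num_heads = 10, 4, 512, 8
head_dim = embed_dim // (num_heads * 2)

# per-head attention output from the batched attention: (bsz * num_heads, tgt_len, head_dim)
attn_output = torch.randn(bsz * num_heads, tgt_len, head_dim)

attn_output = (
    attn_output.transpose(0, 1)
    .contiguous()
    .view(tgt_len, bsz, embed_dim // 2)
)

# assumed out_proj weight shape (embed_dim, embed_dim // 2): half-width in, full model width out.
out_proj_weight = torch.randn(embed_dim, embed_dim // 2)
out_proj_bias = torch.zeros(embed_dim)
attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
print(attn_output.shape)  # torch.Size([10, 4, 512])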