simplify relative position encoding

This commit is contained in:
yaozengwei 2022-07-31 20:31:02 +08:00
parent b338471917
commit 5313ce00d6

View File

@ -859,7 +859,7 @@ class RelPositionalEncoding(torch.nn.Module):
class RelPositionMultiheadAttention(nn.Module): class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding r"""Multi-Head Attention layer with simplified relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
@ -895,24 +895,7 @@ class RelPositionMultiheadAttention(nn.Module):
) )
# linear transformation for positional encoding. # linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) self.linear_pos = ScaledLinear(embed_dim, num_heads, bias=True)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.01)
nn.init.normal_(self.pos_bias_v, std=0.01)
def forward( def forward(
self, self,
@ -1217,35 +1200,25 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask.size(1), src_len key_padding_mask.size(1), src_len
) )
q = q.transpose(0, 1) # (batch, time1, head, d_k) q = q.permute(1, 2, 0, 3) # (batch, head, time1, d_k)
pos_emb_bsz = pos_emb.size(0) pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1 assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self._pos_bias_v()).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score # compute attention score
# first compute matrix a and matrix c # first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul( matrix_ac = torch.matmul(q, k) # (batch, head, time1, time2)
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d # compute matrix b and matrix d
matrix_bd = torch.matmul( pos_emb = self.linear_pos(pos_emb) # (1, 2*time1-1, head)
q_with_bias_v, p matrix_bd = (
) # (batch, head, time1, 2*time1-1) pos_emb.transpose(1, 2).unsqueeze(2).repeat(1, 1, tgt_len, 1)
matrix_bd = self.rel_shift(matrix_bd, left_context) ) # (1, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(
matrix_bd, left_context
) # (1, head, time1, time2)
attn_output_weights = ( attn_output_weights = (
matrix_ac + matrix_bd matrix_ac + matrix_bd