start to modify pos enc

yaozengwei 2022-06-28 16:17:49 +08:00
parent d2ea7d5de5
commit 6fa0ef1e8d


@@ -441,24 +441,7 @@ class RelPositionMultiheadAttention(nn.Module):
         )
 
         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
-        self._reset_parameters()
-
-    def _pos_bias_u(self):
-        return self.pos_bias_u * self.pos_bias_u_scale.exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * self.pos_bias_v_scale.exp()
-
-    def _reset_parameters(self) -> None:
-        nn.init.normal_(self.pos_bias_u, std=0.01)
-        nn.init.normal_(self.pos_bias_v, std=0.01)
+        self.linear_pos = ScaledLinear(embed_dim, num_heads, bias=True)
 
     def forward(
         self,
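
For reference, a minimal shape sketch of what this hunk changes (nn.Linear stands in for icefall's ScaledLinear, and the sizes are made up): instead of projecting the positional embeddings to per-head d_k vectors that are later combined with the learned pos_bias_u / pos_bias_v, linear_pos now maps each relative position directly to one scalar score per head.

import torch
import torch.nn as nn

embed_dim, num_heads, time1 = 512, 8, 10  # made-up sizes for illustration

# Old: one d_k-dim vector per head and relative position, later combined
# with pos_bias_u / pos_bias_v (nn.Linear stands in for ScaledLinear).
old_linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
# New: one scalar score per head and relative position.
new_linear_pos = nn.Linear(embed_dim, num_heads, bias=True)

pos_emb = torch.randn(1, 2 * time1 - 1, embed_dim)  # (1, 2*time1-1, embed_dim)
print(old_linear_pos(pos_emb).shape)  # torch.Size([1, 19, 512])
print(new_linear_pos(pos_emb).shape)  # torch.Size([1, 19, 8])
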
@@ -746,34 +729,23 @@ class RelPositionMultiheadAttention(nn.Module):
                     key_padding_mask.size(1), src_len
                 )
 
-        q = q.transpose(0, 1)  # (batch, time1, head, d_k)
+        q = q.permute(1, 2, 0, 3)  # (batch, head, time1, d_k)
 
         pos_emb_bsz = pos_emb.size(0)
         assert pos_emb_bsz in (1, bsz)  # actually it is 1
-        p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
-
-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
-
-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
-            1, 2
-        )  # (batch, head, time1, d_k)
 
         # compute attention score
         # first compute matrix a and matrix c
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         k = k.permute(1, 2, 3, 0)  # (batch, head, d_k, time2)
-        matrix_ac = torch.matmul(
-            q_with_bias_u, k
-        )  # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q, k)  # (batch, head, time1, time2)
 
         # compute matrix b and matrix d
-        matrix_bd = torch.matmul(
-            q_with_bias_v, p.transpose(-2, -1)
-        )  # (batch, head, time1, 2*time1-1)
-        matrix_bd = self.rel_shift(matrix_bd)
+        pos_emb = self.linear_pos(pos_emb)  # (1, 2*time1-1, head)
+        matrix_bd = (
+            pos_emb.transpose(1, 2).unsqueeze(2).repeat(1, 1, tgt_len, 1)
+        )  # (1, head, time1, 2*time1-1)
+        matrix_bd = self.rel_shift(matrix_bd)  # (1, head, time1, time2)
 
         attn_output_weights = (
             matrix_ac + matrix_bd
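
A rough end-to-end sketch of the new matrix_bd path in this hunk, under the usual assumption that pos_emb comes from RelPositionalEncoding with shape (1, 2*time1-1, embed_dim); the gather-based rel_shift below is a naive stand-in for the strided rel_shift in conformer.py, and nn.Linear again stands in for ScaledLinear.

import torch
import torch.nn as nn

def rel_shift(x: torch.Tensor) -> torch.Tensor:
    # Naive stand-in for the strided rel_shift in conformer.py:
    # x is (batch, head, time1, 2*time1-1), indexed by relative position;
    # out[b, h, i, j] = x[b, h, i, time1 - 1 + j - i], i.e. the score for
    # key position j relative to query position i.
    batch, heads, time1, n = x.shape
    assert n == 2 * time1 - 1
    rel = torch.arange(time1).unsqueeze(0) - torch.arange(time1).unsqueeze(1)
    index = (rel + time1 - 1).view(1, 1, time1, time1).repeat(batch, heads, 1, 1)
    return torch.gather(x, dim=3, index=index)

embed_dim, num_heads, time1 = 512, 8, 10      # made-up sizes
linear_pos = nn.Linear(embed_dim, num_heads)  # stand-in for ScaledLinear

pos_emb = torch.randn(1, 2 * time1 - 1, embed_dim)  # (1, 2*time1-1, embed_dim)
pos_emb = linear_pos(pos_emb)                       # (1, 2*time1-1, head)
matrix_bd = (
    pos_emb.transpose(1, 2).unsqueeze(2).repeat(1, 1, time1, 1)
)                                                   # (1, head, time1, 2*time1-1)
matrix_bd = rel_shift(matrix_bd)                    # (1, head, time1, time2)
print(matrix_bd.shape)                              # torch.Size([1, 8, 10, 10])

Each head's positional contribution now depends only on the relative offset j - i, not on the query content, which appears to be the point of dropping pos_bias_u / pos_bias_v and the per-dimension positional projection.
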