Modified attention.

This commit is contained in:
Fangjun Kuang 2022-01-25 17:43:04 +08:00
parent 8e6fd97c6b
commit dd2acd89fd

View File

@ -449,6 +449,12 @@ class RelPositionMultiheadAttention(nn.Module):
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
# Before applying the softmax to the matrix from the dot-product,
# we multiply that matrix with this one
self.pre_softmax_param = nn.Parameter(
torch.Tensor(num_heads, num_heads)
)
self._reset_parameters()
def _reset_parameters(self) -> None:
@ -459,6 +465,9 @@ class RelPositionMultiheadAttention(nn.Module):
nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)
stdv = 1.0 / math.sqrt(self.num_heads)
nn.init.normal_(self.pre_softmax_param, mean=0, std=stdv)
def forward(
self,
query: Tensor,
@ -780,6 +789,19 @@ class RelPositionMultiheadAttention(nn.Module):
bsz * num_heads, tgt_len, -1
)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, -1
).permute(0, 2, 3, 1)
# now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads)
attn_output_weights = torch.matmul(
attn_output_weights, self.pre_softmax_param
)
attn_output_weights = attn_output_weights.permute(0, 3, 1, 2).reshape(
bsz * num_heads, tgt_len, -1
)
assert list(attn_output_weights.size()) == [
bsz * num_heads,
tgt_len,