Minor fixes after review.

This commit is contained in:
Fangjun Kuang 2022-01-25 18:46:35 +08:00
parent dd2acd89fd
commit 4749619e5a

View File

@ -785,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module):
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1
)
# attn_output_weights is of shape (bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, -1
).permute(0, 2, 3, 1)
attn_output_weights = attn_output_weights.permute(0, 2, 3, 1)
# now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads)
attn_output_weights = torch.matmul(