Minor fixes after review.

This commit is contained in:
Fangjun Kuang 2022-01-25 18:46:35 +08:00
parent dd2acd89fd
commit 4749619e5a

View File

@ -785,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module):
     matrix_ac + matrix_bd
 ) * scaling  # (batch, head, time1, time2)
-attn_output_weights = attn_output_weights.view(
-    bsz * num_heads, tgt_len, -1
-)
-attn_output_weights = attn_output_weights.view(
-    bsz, num_heads, tgt_len, -1
-).permute(0, 2, 3, 1)
+# attn_output_weights is of shape (bsz, num_heads, tgt_len, src_len)
+attn_output_weights = attn_output_weights.permute(0, 2, 3, 1)
 # now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads)
 attn_output_weights = torch.matmul(