mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Minor fixes after review.
This commit is contained in:
parent
dd2acd89fd
commit
4749619e5a
@ -785,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
# attn_output_weights is of shape (bsz, num_heads, tgt_len, src_len)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, -1
|
||||
).permute(0, 2, 3, 1)
|
||||
attn_output_weights = attn_output_weights.permute(0, 2, 3, 1)
|
||||
# now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads)
|
||||
|
||||
attn_output_weights = torch.matmul(
|
||||
|
Loading…
x
Reference in New Issue
Block a user