diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index c2e505944..b27c22ef6 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -785,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + # attn_output_weights is of shape (bsz, num_heads, tgt_len, src_len) - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, -1 - ).permute(0, 2, 3, 1) + attn_output_weights = attn_output_weights.permute(0, 2, 3, 1) # now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads) attn_output_weights = torch.matmul(