From 4749619e5a1378a78293fb4b20ca20eee8e98d84 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 25 Jan 2022 18:46:35 +0800 Subject: [PATCH] Minor fixes after review. --- egs/librispeech/ASR/transducer_stateless/conformer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index c2e505944..b27c22ef6 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -785,13 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) + # attn_output_weights is of shape (bsz, num_heads, tgt_len, src_len) - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, -1 - ).permute(0, 2, 3, 1) + attn_output_weights = attn_output_weights.permute(0, 2, 3, 1) # now attn_output_weights is of shape (bsz, tgt_len, src_len, num_heads) attn_output_weights = torch.matmul(