From 9d2ac7b1ec981bfa19697b613a93589730b9c311 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Mon, 14 Aug 2023 21:05:56 +0800 Subject: [PATCH] minor fixes --- .../zipformer_label_level_algn/alignment_attention_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py index 86a2356b9..c9afe4e9e 100644 --- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py +++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py @@ -374,9 +374,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_scores = pos_scores.reshape(-1, n) pos_scores = torch.gather(pos_scores, dim=1, index=indexes) pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len) - else: + elif not for_reference: pos_scores = pos_scores.as_strided( - (num_heads, b_p_dim, lm_seq_len, lm_seq_len), + (num_heads, b_p_dim, lm_seq_len, am_seq_len), ( pos_scores.stride(0), pos_scores.stride(1),