From b9a8f107a26f72d8f468c2a8bf129f19ed696c19 Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Tue, 25 Jul 2023 23:51:21 +0800 Subject: [PATCH] Update beam_search.py --- .../ASR/pruned_transducer_stateless2/beam_search.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 51ceac0f2..6b9c6e7a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -818,18 +818,17 @@ def deprecated_greedy_search_batch_for_cross_attn( decoder_out = model.joiner.decoder_proj(decoder_out) encoder_out = model.joiner.encoder_proj(encoder_out) - encoder_out_for_attn = encoder_out.unsqueeze(2) - + # encoder_out_for_attn = encoder_out.unsqueeze(2) # decoder_out: (batch_size, 1, decoder_out_dim) for t in range(T): # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) current_encoder_out = model.joiner.label_level_am_attention( - encoder_out, - decoder_out, - encoder_out_lens - ) + encoder_out[:, : t + 1, :].unsqueeze(2), + decoder_out.unsqueeze(2), + encoder_out_lens, + ) logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1),