diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 51ceac0f2..6b9c6e7a1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -818,18 +818,17 @@ def deprecated_greedy_search_batch_for_cross_attn( decoder_out = model.joiner.decoder_proj(decoder_out) encoder_out = model.joiner.encoder_proj(encoder_out) - encoder_out_for_attn = encoder_out.unsqueeze(2) - + # encoder_out_for_attn = encoder_out.unsqueeze(2) # decoder_out: (batch_size, 1, decoder_out_dim) for t in range(T): # current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) current_encoder_out = model.joiner.label_level_am_attention( - encoder_out, - decoder_out, - encoder_out_lens - ) + encoder_out[:, : t + 1, :].unsqueeze(2), + decoder_out.unsqueeze(2), + encoder_out_lens, + ) logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1),