diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index 5815fefa5..38ab16507 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -262,7 +262,7 @@ def modified_beam_search(
     for t in range(T):
         # fmt: off
         current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
-        # current_encoder_out is of shape (1, 1, encoder_out_dim)
+        # current_encoder_out is of shape (1, 1, 1, encoder_out_dim)
         # fmt: on
         A = list(B)
         B = HypothesisList()
@@ -278,11 +278,11 @@ def modified_beam_search(
         # decoder_input is of shape (num_hyps, context_size)
 
         decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
-        # decoder_output is of shape (num_hyps, 1,1, decoder_output_dim)
+        # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
 
         current_encoder_out = current_encoder_out.expand(
             decoder_out.size(0), 1, 1, -1
-        )
+        )  # (num_hyps, 1, 1, encoder_out_dim)
 
         logits = model.joiner(
             current_encoder_out,