From 949b53274cc6e57200c41b49b22c9d3947a7738a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Mar 2022 15:43:16 +0800 Subject: [PATCH] Minor fixes. --- .../ASR/pruned_transducer_stateless/beam_search.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 5815fefa5..38ab16507 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -262,7 +262,7 @@ def modified_beam_search( for t in range(T): # fmt: off current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, encoder_out_dim) + # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) # fmt: on A = list(B) B = HypothesisList() @@ -278,11 +278,11 @@ def modified_beam_search( # decoder_input is of shape (num_hyps, context_size) decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - # decoder_output is of shape (num_hyps, 1,1, decoder_output_dim) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) current_encoder_out = current_encoder_out.expand( decoder_out.size(0), 1, 1, -1 - ) + ) # (num_hyps, 1, 1, encoder_out_dim) logits = model.joiner( current_encoder_out,