diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 9ed9b2ad1..cb48a33f1 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -66,6 +66,9 @@ def greedy_search( # symbols per utterance decoded so far sym_per_utt = 0 + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + while t < T and sym_per_utt < max_sym_per_utt: if sym_per_frame >= max_sym_per_frame: sym_per_frame = 0 @@ -75,7 +78,9 @@ def greedy_search( # fmt: off current_encoder_out = encoder_out[:, t:t+1, :] # fmt: on - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner( + current_encoder_out, decoder_out, encoder_out_len, decoder_out_len + ) # logits is (1, 1, 1, vocab_size) y = logits.argmax().item()