diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index cb48a33f1..989caa802 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -267,6 +267,9 @@ def beam_search( sym_per_utt = 0 + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + decoder_cache: Dict[str, torch.Tensor] = {} while t < T and sym_per_utt < max_sym_per_utt: @@ -299,7 +302,12 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out) + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len, + decoder_out_len, + ) # TODO(fangjun): Ccale the blank posterior