diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index f1d16749c..743fd2c46 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1579,7 +1579,12 @@ def modified_beam_search_lm_rescore( y = y.to(device).to(torch.int64) sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + if LM.lm_type == "rnn": + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + else: + lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths) + lm_scores = lm_scores.reshape(x.size(0), -1) + assert lm_scores.ndim == 2 lm_scores = -1 * lm_scores.sum(dim=1)