diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 66c84b2a9..9e761e10c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1576,7 +1576,12 @@ def modified_beam_search_lm_rescore( y = y.to(device).to(torch.int64) sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + if LM.lm_type == "rnn": + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + else: + lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths) + lm_scores = lm_scores.reshape(x.size(0), -1) + assert lm_scores.ndim == 2 lm_scores = -1 * lm_scores.sum(dim=1)