From 9865cd379f3037f73956edde2529c7651c9a87d4 Mon Sep 17 00:00:00 2001 From: Ryuk Date: Sun, 25 Aug 2024 16:25:09 +0800 Subject: [PATCH] support transformer-lm when using lm rescoring --- .../ASR/pruned_transducer_stateless2/beam_search.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 66c84b2a9..9e761e10c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1576,7 +1576,12 @@ def modified_beam_search_lm_rescore( y = y.to(device).to(torch.int64) sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + if LM.lm_type == "rnn": + lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) + else: + lm_scores = LM.lm(x=x, y=y, x_lens=sentence_token_lengths) + lm_scores = lm_scores.reshape(x.size(0), -1) + assert lm_scores.ndim == 2 lm_scores = -1 * lm_scores.sum(dim=1)