diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index f78b99fbd..754971cfb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -1219,8 +1219,10 @@ def fast_beam_search_with_nbest_rescoring(
     temperature: float = 1.0,
 ) -> Dict[str, List[List[int]]]:
     """It limits the maximum number of symbols per frame to 1.
-    A lattice is first obtained using modified beam search, and then
-    the shortest path within the lattice is used as the final output.
+    A lattice is first obtained using fast beam search, num_path are selected
+    and rescored using a given language model. The shortest path within the
+    lattice is used as the final output.
+
     Args:
       model:
         An instance of `Transducer`.
@@ -1366,14 +1368,17 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     sp: spm.SentencePieceProcessor,
     word_table: k2.SymbolTable,
     rnn_lm_model: torch.nn.Module,
+    rnn_lm_scale_list: List[float],
     oov_word: str = "<UNK>",
     use_double_scores: bool = True,
     nbest_scale: float = 0.5,
     temperature: float = 1.0,
 ) -> Dict[str, List[List[int]]]:
     """It limits the maximum number of symbols per frame to 1.
-    A lattice is first obtained using modified beam search, and then
-    the shortest path within the lattice is used as the final output.
+    A lattice is first obtained using fast beam search, num_path are selected
+    and rescored using a given language model and a rnn-lm.
+    The shortest path within the lattice is used as the final output.
+
     Args:
       model:
         An instance of `Transducer`.
@@ -1400,6 +1405,10 @@ def fast_beam_search_with_nbest_rnn_rescoring(
         The BPE model.
       word_table:
         The word symbol table.
+      rnn_lm_model:
+        A rnn-lm model used for LM rescoring
+      rnn_lm_scale_list:
+        A list of floats representing RNN score scales.
       oov_word:
         OOV words are replaced with this word.
       use_double_scores:
@@ -1501,9 +1510,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
     sos_tokens_row_splits = sos_tokens.shape.row_splits(1)
     sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]
 
-    x_tokens = sos_tokens.pad(
-        mode="constant", padding_value=model.decoder.blank_id
-    )
+    x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
     y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)
 
     x_tokens = x_tokens.to(torch.int64)
@@ -1517,7 +1524,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
 
     ans: Dict[str, List[List[int]]] = {}
     for n_scale in ngram_lm_scale_list:
-        for rnn_scale in ngram_lm_scale_list:
+        for rnn_scale in rnn_lm_scale_list:
             key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
             tot_scores = (
                 am_scores.values
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index c8d709eba..c3a03f2e1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -605,6 +605,7 @@ def decode_one_batch(
             sp=sp,
             word_table=word_table,
             rnn_lm_model=rnn_lm_model,
+            rnn_lm_scale_list=ngram_lm_scale_list,
             use_double_scores=True,
             nbest_scale=params.nbest_scale,
             temperature=params.temperature,
diff --git a/icefall/decode.py b/icefall/decode.py
index 680e29619..e596876f4 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -1006,6 +1006,8 @@ def rescore_with_rnn_lm(
       An FsaVec with axes [utt][state][arc].
     num_paths:
       Number of paths to extract from the given lattice for rescoring.
+    rnn_lm_model:
+      A rnn-lm model used for LM rescoring
     model:
       A transformer model. See the class "Transformer" in
       conformer_ctc/transformer.py for its interface.
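For reference, below is a minimal sketch (not the icefall implementation) of the score combination that the corrected inner loop performs: each pair drawn from the two scale lists produces one rescored entry, keyed the same way as in the patch. The three score tensors are hypothetical stand-ins for the per-path acoustic, n-gram LM, and RNN LM scores that the real function derives from the lattice and the language models.

# Hedged sketch only: illustrates the (ngram_lm_scale, rnn_lm_scale) grid used in
# fast_beam_search_with_nbest_rnn_rescoring. Names and tensors are placeholders.
from typing import Dict, List

import torch


def combine_scores(
    am_scores: torch.Tensor,        # one acoustic score per n-best path
    ngram_lm_scores: torch.Tensor,  # one n-gram LM score per path
    rnn_lm_scores: torch.Tensor,    # one RNN LM score per path
    ngram_lm_scale_list: List[float],
    rnn_lm_scale_list: List[float],
) -> Dict[str, torch.Tensor]:
    ans: Dict[str, torch.Tensor] = {}
    for n_scale in ngram_lm_scale_list:
        # Before this fix the inner loop re-used ngram_lm_scale_list, so the
        # RNN-LM scales passed by the caller were never applied.
        for rnn_scale in rnn_lm_scale_list:
            key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}"
            ans[key] = (
                am_scores
                + n_scale * ngram_lm_scores
                + rnn_scale * rnn_lm_scores
            )
    return ans

With, for example, ngram_lm_scale_list=[0.1, 0.3] and rnn_lm_scale_list=[0.5], this produces two keyed results, matching the per-scale-pair keys that the decoding script uses when reporting results.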