diff --git a/README.md b/README.md
index fcba0723b..7213d8460 100644
--- a/README.md
+++ b/README.md
@@ -257,8 +257,8 @@ We provide some models for this recipe: [Pruned stateless RNN-T_2: Conformer enc
 |                      | Dev  | Test-Net | Test-Meeting |
 |----------------------|------|----------|--------------|
 | greedy search        | 7.80 | 8.75     | 13.49        |
+| modified beam search | 7.76 | 8.71     | 13.41        |
 | fast beam search     | 7.94 | 8.74     | 13.80        |
-| modified beam search | 7.76 | 8.71     | 13.41        |
 
 #### Pruned stateless RNN-T_5: Conformer encoder + Embedding decoder + k2 pruned RNN-T loss (trained with L subset)
 **Streaming**:
diff --git a/egs/wenetspeech/ASR/RESULTS.md b/egs/wenetspeech/ASR/RESULTS.md
index cc36ae4f2..658ad4a9b 100644
--- a/egs/wenetspeech/ASR/RESULTS.md
+++ b/egs/wenetspeech/ASR/RESULTS.md
@@ -84,7 +84,10 @@ When training with the L subset, the CERs are
 |------------------------------------|------|----------|--------------|------------------------------------------|
 | greedy search                      | 7.80 | 8.75     | 13.49        | --epoch 10, --avg 2, --max-duration 100  |
 | modified beam search (beam size 4) | 7.76 | 8.71     | 13.41        | --epoch 10, --avg 2, --max-duration 100  |
-| fast beam search (set as default)  | 7.94 | 8.74     | 13.80        | --epoch 10, --avg 2, --max-duration 1500 |
+| fast beam search (1best)           | 7.94 | 8.74     | 13.80        | --epoch 10, --avg 2, --max-duration 1500 |
+| fast beam search (nbest)           | 9.82 | 10.98    | 16.37        | --epoch 10, --avg 2, --max-duration 600  |
+| fast beam search (nbest oracle)    | 6.88 | 7.18     | 11.77        | --epoch 10, --avg 2, --max-duration 600  |
+| fast beam search (nbest LG, ngram_lm_scale=0.35) | 8.83 | 9.88 | 15.47 | --epoch 10, --avg 2, --max-duration 600 |
 
 The training command for reproducing is given below:
 
@@ -131,7 +134,7 @@ avg=2
   --decoding-method modified_beam_search \
   --beam-size 4
 
-## fast beam search
+## fast beam search (1best)
 ./pruned_transducer_stateless2/decode.py \
   --epoch $epoch \
   --avg $avg \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --lang-dir data/lang_char \
   --max-duration 1500 \
   --decoding-method fast_beam_search \
@@ -142,6 +145,47 @@ avg=2
   --beam 4 \
   --max-contexts 4 \
   --max-states 8
+
+## fast beam search (nbest)
+./pruned_transducer_stateless2/decode.py \
+  --epoch 10 \
+  --avg 2 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --max-duration 600 \
+  --decoding-method fast_beam_search_nbest \
+  --beam 20.0 \
+  --max-contexts 8 \
+  --max-states 64 \
+  --num-paths 200 \
+  --nbest-scale 0.5
+
+## fast beam search (nbest oracle WER)
+./pruned_transducer_stateless2/decode.py \
+  --epoch 10 \
+  --avg 2 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --max-duration 600 \
+  --decoding-method fast_beam_search_nbest_oracle \
+  --beam 20.0 \
+  --max-contexts 8 \
+  --max-states 64 \
+  --num-paths 200 \
+  --nbest-scale 0.5
+
+## fast beam search (with LG)
+./pruned_transducer_stateless2/decode.py \
+  --epoch 10 \
+  --avg 2 \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --max-duration 600 \
+  --decoding-method fast_beam_search_nbest_LG \
+  --ngram-lm-scale 0.35 \
+  --beam 20.0 \
+  --max-contexts 8 \
+  --max-states 64
 ```
 
 When training with the M subset, the CERs are
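A note on the new `fast beam search (nbest oracle)` row above: the oracle number is not a real decoding result. For every utterance it picks, from the n-best candidates, the one closest to the reference, so it lower-bounds the CER that any rescoring of the same n-best lists could reach; that is why the 6.88 beats every real method and why it is useful for judging how much headroom rescoring (e.g., with LG) has. A minimal sketch of the idea follows, using a toy candidate list rather than icefall's lattice-sampled paths:

```python
# Toy illustration of an "nbest oracle" score. icefall's
# fast_beam_search_nbest_oracle samples paths from the decoding lattice;
# here we assume the n-best candidates are already plain strings.

def edit_distance(a: str, b: str) -> int:
    """Levenshtein distance via dynamic programming."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                # deletion
                           cur[j - 1] + 1,             # insertion
                           prev[j - 1] + (ca != cb)))  # substitution
        prev = cur
    return prev[-1]

def oracle_hyp(nbest: list, ref: str) -> str:
    """Pick the candidate closest to the reference."""
    return min(nbest, key=lambda hyp: edit_distance(hyp, ref))

print(oracle_hyp(["你好世届", "你好世界", "你好时节"], ref="你好世界"))  # 你好世界
```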
diff --git a/egs/wenetspeech/ASR/local/compile_lg.py b/egs/wenetspeech/ASR/local/compile_lg.py
new file mode 120000
index 000000000..462d6d3fb
--- /dev/null
+++ b/egs/wenetspeech/ASR/local/compile_lg.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compile_lg.py
\ No newline at end of file
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 9449e5d1e..6573a94ad 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -225,3 +225,34 @@ if [ $stage -le 16 ] && [ $stop_stage -ge 16 ]; then
       --lang-dir data/lang_char
   fi
 fi
+
+# If you don't want to use LG for decoding, the following steps are not necessary.
+if [ $stage -le 17 ] && [ $stop_stage -ge 17 ]; then
+  log "Stage 17: Prepare G"
+  # It will take about 20 minutes.
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+  lang_char_dir=data/lang_char
+  if [ ! -f $lang_char_dir/3-gram.unpruned.arpa ]; then
+    python ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/text_words_segmentation \
+      -lm $lang_char_dir/3-gram.unpruned.arpa
+  fi
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building LG
+    python3 -m kaldilm \
+      --read-symbol-table="$lang_char_dir/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then
+  log "Stage 18: Compile LG"
+  lang_char_dir=data/lang_char
+  python ./local/compile_lg.py --lang-dir $lang_char_dir
+fi
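Stage 18 runs `local/compile_lg.py`, which (per the symlink above) is the librispeech script. Conceptually it composes the lexicon FST `L` with the 3-gram `G` built in stage 17 and saves the result as `LG.pt` for the `fast_beam_search_nbest_LG` decoding method. A condensed sketch of that flow is below; details such as determinization and removal of disambiguation symbols are elided, and the file names are assumptions inferred from the stages above:

```python
import k2
import torch

lang_dir = "data/lang_char"

# Lexicon FST with disambiguation symbols, built by the earlier lang-dir stages.
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

# G was written by kaldilm in OpenFST text format in stage 17.
with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

# Arc-sort both operands, compose, and trim unreachable states.
LG = k2.compose(k2.arc_sort(L), k2.arc_sort(G))
LG = k2.connect(LG)
LG = k2.arc_sort(LG)

torch.save(LG.as_dict(), f"{lang_dir}/LG.pt")
```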
diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
index 41e7a0f44..7c06cdb3d 100755
--- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
+++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py
@@ -37,7 +37,7 @@ When training with the L subset, usage:
     --decoding-method modified_beam_search \
     --beam-size 4
 
-(3) fast beam search
+(3) fast beam search (1best)
 ./pruned_transducer_stateless2/decode.py \
     --epoch 10 \
     --avg 2 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --lang-dir data/lang_char \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
@@ -48,6 +48,46 @@ When training with the L subset, usage:
     --beam 4 \
     --max-contexts 4 \
     --max-states 8
+
+(4) fast beam search (nbest)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 10 \
+    --avg 2 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --lang-dir data/lang_char \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(5) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 10 \
+    --avg 2 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --lang-dir data/lang_char \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_oracle \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5
+
+(6) fast beam search (with LG)
+./pruned_transducer_stateless2/decode.py \
+    --epoch 10 \
+    --avg 2 \
+    --exp-dir ./pruned_transducer_stateless2/exp \
+    --lang-dir data/lang_char \
+    --max-duration 600 \
+    --decoding-method fast_beam_search_nbest_LG \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64
 """
 
@@ -63,6 +103,9 @@ import torch.nn as nn
 from asr_datamodule import WenetSpeechAsrDataModule
 from beam_search import (
     beam_search,
+    fast_beam_search_nbest,
+    fast_beam_search_nbest_LG,
+    fast_beam_search_nbest_oracle,
     fast_beam_search_one_best,
     greedy_search,
     greedy_search_batch,
@@ -70,6 +113,7 @@ from beam_search import (
 )
 from train import get_params, get_transducer_model
 
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     find_checkpoints,
@@ -151,6 +195,11 @@ def get_parser():
         - beam_search
         - modified_beam_search
         - fast_beam_search
+        - fast_beam_search_nbest
+        - fast_beam_search_nbest_oracle
+        - fast_beam_search_nbest_LG
+        If you use fast_beam_search_nbest_LG, you have to
+        specify `--lang-dir`, which should contain `LG.pt`.
         """,
     )
 
@@ -173,6 +222,16 @@ def get_parser():
         Used only when --decoding-method is fast_beam_search""",
     )
 
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=0.35,
+        help="""
+        Used only when --decoding-method is fast_beam_search_nbest_LG.
+        It specifies the scale for n-gram LM scores.
+        """,
+    )
+
     parser.add_argument(
         "--max-contexts",
         type=int,
@@ -204,6 +263,24 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )
 
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=200,
+        help="""Number of paths for nbest decoding.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, or fast_beam_search_nbest_oracle""",
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""Scale applied to lattice scores when computing nbest paths.
+        Used only when the decoding method is fast_beam_search_nbest,
+        fast_beam_search_nbest_LG, or fast_beam_search_nbest_oracle""",
+    )
+
     return parser
 
 
@@ -211,6 +288,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
@@ -267,6 +345,50 @@ def decode_one_batch(
         )
         for i in range(encoder_out.size(0)):
             hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "fast_beam_search_nbest_LG":
+        hyp_tokens = fast_beam_search_nbest_LG(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for hyp in hyp_tokens:
+            sentence = "".join([lexicon.word_table[i] for i in hyp])
+            hyps.append(list(sentence))
+    elif params.decoding_method == "fast_beam_search_nbest":
+        hyp_tokens = fast_beam_search_nbest(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            nbest_scale=params.nbest_scale,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "fast_beam_search_nbest_oracle":
+        hyp_tokens = fast_beam_search_nbest_oracle(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+            num_paths=params.num_paths,
+            ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
+            nbest_scale=params.nbest_scale,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
     elif (
         params.decoding_method == "greedy_search"
         and params.max_sym_per_frame == 1
@@ -331,6 +453,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -373,6 +496,7 @@ def decode_dataset(
             params=params,
             model=model,
             lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -454,6 +578,9 @@ def main():
         "greedy_search",
         "beam_search",
         "fast_beam_search",
+        "fast_beam_search_nbest",
+        "fast_beam_search_nbest_LG",
+        "fast_beam_search_nbest_oracle",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method
@@ -463,6 +590,13 @@ def main():
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
+        if (
+            params.decoding_method == "fast_beam_search_nbest"
+            or params.decoding_method == "fast_beam_search_nbest_oracle"
+        ):
+            params.suffix += f"-nbest-scale-{params.nbest_scale}"
     elif "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
@@ -482,6 +616,11 @@ def main():
     params.blank_id = lexicon.token_table["<blk>"]
     params.vocab_size = max(lexicon.tokens) + 1
 
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -513,8 +652,18 @@ def main():
     model.eval()
     model.device = device
 
-    if params.decoding_method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    if "fast_beam_search" in params.decoding_method:
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lg_filename = params.lang_dir + "/LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
@@ -610,6 +759,7 @@ def main():
         params=params,
         model=model,
         lexicon=lexicon,
+        graph_compiler=graph_compiler,
         decoding_graph=decoding_graph,
     )
     save_results(
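One subtlety in the `decode_one_batch` changes: `fast_beam_search_nbest_LG` returns word-level ids (the output symbols of LG), while every other method returns character tokens. The `"".join(...)` followed by `list(sentence)` converts each word hypothesis back into a character list, so all methods are scored with the same character-level CER. A toy illustration, with made-up word-table entries standing in for `lexicon.word_table`:

```python
# Toy stand-in for lexicon.word_table; real entries come from words.txt.
word_table = {1: "你好", 2: "世界"}

hyp = [1, 2]  # word ids as returned by fast_beam_search_nbest_LG
sentence = "".join(word_table[i] for i in hyp)  # "你好世界"
hyp_chars = list(sentence)                      # ["你", "好", "世", "界"]
print(hyp_chars)  # same unit as the character-based decoding methods
```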