diff --git a/egs/aishell2/ASR/RESULTS.md b/egs/aishell2/ASR/RESULTS.md
index 6983cc80d..6035fd4cf 100644
--- a/egs/aishell2/ASR/RESULTS.md
+++ b/egs/aishell2/ASR/RESULTS.md
@@ -12,7 +12,10 @@ When training with context size equals to 1, the WERs are
 |------------------------------------|-------|----------|----------------------------------|
 | greedy search                      | 5.57  | 5.89     | --epoch 25, --avg 5, --max-duration 600 |
 | modified beam search (beam size 4) | 5.32  | 5.56     | --epoch 25, --avg 5, --max-duration 600 |
-| fast beam search (set as default)  | 5.5   | 5.78     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search (set as default)  | 5.5   | 5.78     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest             | 5.46  | 5.74     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search oracle            | 1.92  | 2.2      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG          | 5.59  | 5.93     | --epoch 25, --avg 5, --max-duration 600 |
 
 The training command for reproducing is given below:
 
@@ -37,11 +40,13 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 The decoding command is:
 ```
-for method in greedy_search modified_beam_search fast_beam_search; do
+for method in greedy_search modified_beam_search \
+    fast_beam_search fast_beam_search_nbest \
+    fast_beam_search_nbest_oracle fast_beam_search_nbest_LG; do
   ./pruned_transducer_stateless5/decode.py \
     --epoch 25 \
     --avg 5 \
-    --exp-dir /result \
+    --exp-dir ./pruned_transducer_stateless5/exp \
     --max-duration 600 \
     --decoding-method $method \
     --max-sym-per-frame 1 \
@@ -51,7 +56,13 @@ for method in greedy_search modified_beam_search fast_beam_search; do
     --encoder-dim 384 \
     --decoder-dim 512 \
     --joiner-dim 512 \
-    --use-averaged-model True
+    --context-size 1 \
+    --beam 20.0 \
+    --max-contexts 8 \
+    --max-states 64 \
+    --num-paths 200 \
+    --nbest-scale 0.5 \
+    --use-averaged-model False
 done
 ```
 The tensorboard training log can be found at
@@ -66,6 +77,9 @@ When training with context size equals to 2, the WERs are
 | greedy search                      | 5.47  | 5.81     | --epoch 25, --avg 5, --max-duration 600 |
 | modified beam search (beam size 4) | 5.38  | 5.61     | --epoch 25, --avg 5, --max-duration 600 |
 | fast beam search (set as default)  | 5.36  | 5.61     | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest             | 5.37  | 5.6      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search oracle            | 2.04  | 2.2      | --epoch 25, --avg 5, --max-duration 600 |
+| fast beam search nbest LG          | 5.59  | 5.82     | --epoch 25, --avg 5, --max-duration 600 |
 
 The tensorboard training log can be found at
 https://tensorboard.dev/experiment/5AxJ8LHoSre8kDAuLp4L7Q/#scalars
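The gap between the "oracle" rows (1.92 / 2.2) and the other rows is expected: `fast_beam_search_nbest_oracle` scores the best hypothesis found anywhere in the n-best list against the reference, so it is a lower bound on what the n-best list could achieve rather than what the model would actually pick. The snippet below is only a toy illustration of that idea, not the icefall implementation; the edit-distance helper and the sample strings are made up.

```python
from typing import List


def edit_distance(ref: List[str], hyp: List[str]) -> int:
    # Plain Levenshtein distance over token sequences (rolling 1-D table).
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(
                dp[j] + 1,        # reference token dropped
                dp[j - 1] + 1,    # extra hypothesis token
                prev + (r != h),  # match or substitution
            )
    return dp[-1]


def oracle_hyp(nbest: List[List[str]], ref: List[str]) -> List[str]:
    # The "oracle" picks the n-best entry closest to the reference.
    return min(nbest, key=lambda hyp: edit_distance(ref, hyp))


nbest = [list("今天天汽很好"), list("今天天气很好"), list("今天天气狠好")]
ref = list("今天天气很好")
print(oracle_hyp(nbest, ref))  # the exact match, i.e. 0 errors for this utterance
```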
diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh
index 201ee2947..06810bfdd 100755
--- a/egs/aishell2/ASR/prepare.sh
+++ b/egs/aishell2/ASR/prepare.sh
@@ -112,9 +112,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
 fi
 
+lang_char_dir=data/lang_char
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare char based lang"
-  lang_char_dir=data/lang_char
   mkdir -p $lang_char_dir
 
   # Prepare text.
@@ -151,3 +151,31 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     python3 ./local/prepare_char.py
   fi
 fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+
+  if [ ! -f ${lang_char_dir}/3-gram.unpruned.arpa ]; then
+    ./shared/make_kn_lm.py \
+      -ngram-order 3 \
+      -text $lang_char_dir/text_words_segmentation \
+      -lm $lang_char_dir/3-gram.unpruned.arpa
+  fi
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building LG
+    python3 -m kaldilm \
+      --read-symbol-table="$lang_char_dir/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $lang_char_dir/3-gram.unpruned.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compile LG"
+  ./local/compile_lg.py --lang-dir $lang_char_dir
+fi
diff --git a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
index bdaba4c5b..46206d819 100755
--- a/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
+++ b/egs/aishell2/ASR/pruned_transducer_stateless5/decode.py
@@ -123,6 +123,7 @@ from beam_search import (
 )
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -306,6 +307,7 @@ def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
@@ -376,7 +378,8 @@
             nbest_scale=params.nbest_scale,
         )
         for hyp in hyp_tokens:
-            hyps.append([lexicon.word_table[i] for i in hyp])
+            sentence = "".join([lexicon.word_table[i] for i in hyp])
+            hyps.append(list(sentence))
     elif params.decoding_method == "fast_beam_search_nbest":
         hyp_tokens = fast_beam_search_nbest(
             model=model,
@@ -401,7 +404,7 @@
             max_contexts=params.max_contexts,
             max_states=params.max_states,
             num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
+            ref_texts=graph_compiler.texts_to_ids(supervisions["text"]),
             nbest_scale=params.nbest_scale,
         )
         for i in range(encoder_out.size(0)):
@@ -473,6 +476,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     lexicon: Lexicon,
+    graph_compiler: CharCtcTrainingGraphCompiler,
     decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.
@@ -515,6 +519,7 @@ def decode_dataset(
             params=params,
             model=model,
             lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
             batch=batch,
         )
@@ -642,6 +647,11 @@ def main():
     params.unk_id = lexicon.token_table["<unk>"]
     params.vocab_size = max(lexicon.tokens) + 1
 
+    graph_compiler = CharCtcTrainingGraphCompiler(
+        lexicon=lexicon,
+        device=device,
+    )
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -728,7 +738,18 @@ def main():
     model.eval()
 
     if "fast_beam_search" in params.decoding_method:
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+        if params.decoding_method == "fast_beam_search_nbest_LG":
+            lexicon = Lexicon(params.lang_dir)
+            lg_filename = params.lang_dir / "LG.pt"
+            logging.info(f"Loading {lg_filename}")
+            decoding_graph = k2.Fsa.from_dict(
+                torch.load(lg_filename, map_location=device)
+            )
+            decoding_graph.scores *= params.ngram_lm_scale
+        else:
+            decoding_graph = k2.trivial_graph(
+                params.vocab_size - 1, device=device
+            )
     else:
         decoding_graph = None
 
@@ -753,6 +774,7 @@ def main():
             params=params,
             model=model,
             lexicon=lexicon,
+            graph_compiler=graph_compiler,
             decoding_graph=decoding_graph,
         )
 
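One detail worth calling out in decode.py: the LG-based search returns word ids, while every other decoding method in this script produces character-level hypotheses, which is why the `fast_beam_search_nbest_LG` branch now joins the looked-up words and splits the result back into characters before scoring. Below is a toy illustration of that step with a made-up word table; in the real script the mapping comes from `lexicon.word_table`.

```python
# Hypothetical word table: word id -> Chinese word (not the real aishell2 table).
word_table = {1: "今天", 2: "天气", 3: "很好"}
hyp = [1, 2, 3]  # word ids as returned by the LG-based search

# Join the words, then re-split into characters so the hypothesis is
# comparable with the character-level references used for scoring.
sentence = "".join([word_table[i] for i in hyp])
hyps = [list(sentence)]
print(hyps)  # [['今', '天', '天', '气', '很', '好']]
```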
diff --git a/egs/wenetspeech/ASR/local/prepare_words.py b/egs/wenetspeech/ASR/local/prepare_words.py
index 65aca2983..d5f833db1 100644
--- a/egs/wenetspeech/ASR/local/prepare_words.py
+++ b/egs/wenetspeech/ASR/local/prepare_words.py
@@ -75,6 +75,16 @@ def main():
 
     logging.info("Starting writing the words.txt")
     f_out = open(output_file, "w", encoding="utf-8")
+
+    # LG decoding needs the symbols below.
+    id1, id2, id3 = (
+        str(len(new_lines)),
+        str(len(new_lines) + 1),
+        str(len(new_lines) + 2),
+    )
+    add_words = ["#0 " + id1, "<s> " + id2, "</s> " + id3]
+    new_lines.extend(add_words)
+
     for line in new_lines:
         f_out.write(line)
         f_out.write("\n")
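For context, the sketch below shows roughly what the tail of the generated words.txt ends up containing after this change. The existing entries are invented, but the three appended lines follow the script: `#0` is the disambiguation symbol that the kaldilm command in prepare.sh references via `--disambig-symbol='#0'`, and `<s>`/`</s>` are the sentence symbols the LG build expects.

```python
# Toy stand-in for the word -> id lines already accumulated by prepare_words.py.
new_lines = ["<eps> 0", "今天 1", "天气 2", "很好 3"]

# Same logic as the diff above: append the three extra symbols with the next ids.
id1, id2, id3 = (
    str(len(new_lines)),
    str(len(new_lines) + 1),
    str(len(new_lines) + 2),
)
new_lines.extend(["#0 " + id1, "<s> " + id2, "</s> " + id3])

print("\n".join(new_lines[-3:]))
# #0 4
# <s> 5
# </s> 6
```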