diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 9a06fbe9f..4c730c4ae 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -15,6 +15,8 @@ It uses pruned RNN-T. |------------------------|------|------|---------------------------------------| | greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | | modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 | | fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | Training command is: @@ -73,6 +75,78 @@ for epoch in 29; do done ``` +We provide the option of shallow fusion with a RNN language model. The pre-trained language model is +available at . To decode with the language model, +please use the following command: + +```bash +# download pre-trained model +git lfs install +git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + +aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/ + +pushd ${aishell_exp}/exp +ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt +popd + +# download RNN LM +git lfs install +git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm +rnnlm_dir=icefall-aishell-rnn-lm + +# RNNLM shallow fusion +for lm_scale in $(seq 0.26 0.02 0.34); do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 +done + +# RNNLM Low-order density ratio (LODR) with a 2-gram + +cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt + +for lm_scale in 0.48; do + for LODR_scale in -0.28; do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_LODR \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 \ + --tokens-ngram 2 \ + --backoff-id 4336 \ + --ngram-lm-scale $LODR_scale + done +done + +``` + Pretrained models, training logs, decoding logs, and decoding results are available at diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py old mode 100644 new mode 100755 diff --git a/egs/aishell/ASR/local/sort_lm_training_data.py b/egs/aishell/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/aishell/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index cf4ee7818..bd34c1f44 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -230,12 +230,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then if [ ! -f $dl_dir/lm/aishell-train-word.txt ]; then cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt fi - + + # training words ./local/prepare_char_lm_training_data.py \ --lang-char data/lang_char \ --lm-data $dl_dir/lm/aishell-train-word.txt \ --lm-archive $out_dir/lm_data.pt + # valid words if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid @@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then --lm-data $dl_dir/lm/aishell-valid-word.txt \ --lm-archive $out_dir/lm_data_valid.pt + # test words if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid @@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then --hidden-dim 512 \ --num-layers 2 \ --batch-size 400 \ - --exp-dir rnnlm_char/exp \ - --lm-data data/lm_training_char/sorted_lm_data.pt \ - --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \ + --exp-dir rnnlm_char/exp_aishell1_small \ + --lm-data data/lm_char/sorted_lm_data_aishell1.pt \ + --lm-data-valid data/lm_char/sorted_lm_data_valid.pt \ --vocab-size 4336 \ --master-port 12345 fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 954d9dc7e..27c64efaa 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -54,6 +54,40 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) modified beam search (with LM shallow fusion) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(6) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.48 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.28 \ """ @@ -74,9 +108,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -212,6 +249,60 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -223,6 +314,9 @@ def decode_one_batch( token_table: k2.SymbolTable, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -287,6 +381,24 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) else: hyp_tokens = [] batch_size = encoder_out.size(0) @@ -334,6 +446,9 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -379,6 +494,9 @@ def decode_dataset( token_table=token_table, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -445,6 +563,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -458,6 +577,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -479,6 +600,19 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -588,6 +722,35 @@ def main(): else: decoding_graph = None + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + lm_filename, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -610,6 +773,9 @@ def main(): model=model, token_table=lexicon.token_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 deleted file mode 120000 index bcd4abc2f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 +++ /dev/null @@ -1 +0,0 @@ -/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 6c58a57e1..73207017b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -550,7 +550,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -561,7 +560,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 999d793a4..75edf0c54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1863,7 +1863,6 @@ def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, LODR_lm: NgramLm, LODR_lm_scale: float, LM: LmScorer, @@ -1883,8 +1882,6 @@ def modified_beam_search_LODR( encoder_out_lens (torch.Tensor): A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: - Sentence piece generator. LODR_lm: A low order n-gram LM, whose score will be subtracted during shallow fusion LODR_lm_scale: @@ -1912,7 +1909,7 @@ def modified_beam_search_LODR( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device @@ -2137,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, LM: LmScorer, beam: int = 4, return_timestamps: bool = False, @@ -2176,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index b39007dfc..7c62bfa58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -675,7 +675,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -686,7 +685,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index af0b2d9fc..7a3e63218 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -586,7 +586,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -597,7 +596,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 576621e24..55a2493e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -533,7 +533,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -544,7 +543,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/wenetspeech/ASR/local/text2segments.py b/egs/wenetspeech/ASR/local/text2segments.py index df5b3c119..bdf5a3984 100644 --- a/egs/wenetspeech/ASR/local/text2segments.py +++ b/egs/wenetspeech/ASR/local/text2segments.py @@ -40,8 +40,8 @@ from tqdm import tqdm # and 'data()' is only supported in static graph mode. So if you # want to use this api, should call 'paddle.enable_static()' before # this api to enter static graph mode. -paddle.enable_static() -paddle.disable_signal_handler() +# paddle.enable_static() +# paddle.disable_signal_handler() jieba.enable_paddle() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 50a00253d..f7b521794 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then log "Stage 18: Compile LG" python ./local/compile_lg.py --lang-dir $lang_char_dir fi + +# prepare RNNLM data +if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then + log "Stage 19: Prepare LM training data" + + log "Processing char based data" + text_out_dir=data/lm_char + + mkdir -p $text_out_dir + + log "Genearating training text data" + + if [ ! -f $text_out_dir/lm_data.pt ]; then + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $lang_char_dir/text_words_segmentation \ + --lm-archive $text_out_dir/lm_data.pt + fi + + log "Generating DEV text data" + # prepare validation text data + if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then + valid_text=${text_out_dir}/ + + gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $text_out_dir/valid_text + + python3 ./local/text2segments.py \ + --num-process $nj \ + --input-file $text_out_dir/valid_text \ + --output-file $text_out_dir/valid_text_words_segmentation + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $text_out_dir/valid_text_words_segmentation \ + --lm-archive $text_out_dir/lm_data_valid.pt + + # prepare TEST text data + if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then + log "Prepare text for test set." + for test_set in TEST_MEETING TEST_NET; do + gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text + + python3 ./local/text2segments.py \ + --num-process $nj \ + --input-file $text_out_dir/${test_set}_text \ + --output-file $text_out_dir/${test_set}_text_words_segmentation + done + + cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $text_out_dir/test_text_words_segmentation \ + --lm-archive $text_out_dir/lm_data_test.pt + +fi + +# sort RNNLM data +if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then + text_out_dir=data/lm_char + + log "Sort lm data" + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data.pt \ + --out-lm-data $text_out_dir/sorted_lm_data.pt \ + --out-statistics $text_out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data_valid.pt \ + --out-lm-data $text_out_dir/sorted_lm_data-valid.pt \ + --out-statistics $text_out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data_test.pt \ + --out-lm-data $text_out_dir/sorted_lm_data-test.pt \ + --out-statistics $text_out_dir/statistics-test.txt +fi + +export CUDA_VISIBLE_DEVICES="0,1" + +if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then + log "Stage 21: Train RNN LM model" + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data data/lm_char/sorted_lm_data.pt \ + --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12340 +fi \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index de12b2ff0..46ba6b005 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -2,6 +2,7 @@ # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,6 +92,22 @@ When training with the L subset, the streaming usage: --causal-convolution 1 \ --decode-chunk-size 16 \ --left-context 64 + +(4) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 """ @@ -111,9 +128,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -224,6 +244,16 @@ def get_parser(): Used only when --decoding-method is fast_beam_search""", ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -277,6 +307,50 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -288,6 +362,9 @@ def decode_one_batch( lexicon: Lexicon, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -374,6 +451,28 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: batch_size = encoder_out.size(0) @@ -419,6 +518,9 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -432,6 +534,8 @@ def decode_dataset( decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -449,7 +553,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 100 else: - log_interval = 2 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -463,6 +567,9 @@ def decode_dataset( lexicon=lexicon, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -524,6 +631,7 @@ def save_results( def main(): parser = get_parser() WenetSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -535,6 +643,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -549,6 +659,22 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -558,6 +684,7 @@ def main(): logging.info(f"Device: {device}") + # import pdb; pdb.set_trace() lexicon = Lexicon(params.lang_dir) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 @@ -652,6 +779,37 @@ def main(): model.to(device) model.eval() model.device = device + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # import pdb; pdb.set_trace() + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + num_param = sum([p.numel() for p in LM.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + else: + LM = None if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) @@ -684,6 +842,9 @@ def main(): model=model, lexicon=lexicon, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( params=params, diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py index 0468befd0..5e2783a47 100644 --- a/icefall/lm_wrapper.py +++ b/icefall/lm_wrapper.py @@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module): def add_arguments(cls, parser): # LM general arguments parser.add_argument( - "--vocab-size", + "--lm-vocab-size", type=int, default=500, ) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index f75a89590..cc566bd92 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -33,7 +33,7 @@ import torch from dataset import get_dataloader from model import RnnLmModel -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, setup_logger, str2bool @@ -49,6 +49,7 @@ def get_parser(): help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) + parser.add_argument( "--avg", type=int, @@ -58,6 +59,16 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -154,7 +165,14 @@ def main(): params = AttributeDict(vars(args)) - setup_logger(f"{params.exp_dir}/log-ppl/") + if params.iter > 0: + setup_logger( + f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}" + ) + else: + setup_logger( + f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}" + ) logging.info("Computing perplexity started") logging.info(params) @@ -173,19 +191,39 @@ def main(): tie_weights=params.tie_weights, ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): - if start >= 0: + if i >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) num_param_requires_grad = sum( diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index 2411cb1f0..a8598a1ce 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -25,7 +25,7 @@ from pathlib import Path import torch from model import RnnLmModel -from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, load_averaged_model, str2bool @@ -51,6 +51,16 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--vocab-size", type=int, @@ -133,11 +143,36 @@ def main(): model.to(device) - if params.avg == 1: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: - model = load_averaged_model( - params.exp_dir, model, params.epoch, params.avg, device + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False ) model.to("cpu") diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index f43e66cd2..91df4f921 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -178,6 +179,33 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--lr", + type=float, + default=1e-3, + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=200, + help="""Maximum number of tokens in a sentence. This is used + to adjust batch-size dynamically""", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + return parser @@ -190,16 +218,15 @@ def get_params() -> AttributeDict: "sos_id": 1, "eos_id": 1, "blank_id": 0, - "lr": 1e-3, "weight_decay": 1e-6, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 200, + "log_interval": 100, "reset_interval": 2000, - "valid_interval": 5000, + "valid_interval": 200, "env_info": get_env_info(), } ) @@ -382,6 +409,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. @@ -430,6 +458,19 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + rank=rank, + ) + if batch_idx % params.log_interval == 0: # Note: "frames" here means "num_tokens" this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) @@ -580,6 +621,7 @@ def run(rank, world_size, args): valid_dl=valid_dl, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) save_checkpoint(