From 4b5bc480e8a5ac253dcd22b08dfa59083dadd6fd Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 30 Nov 2022 17:26:05 +0800 Subject: [PATCH] Add low-order density ratio in RNNLM shallow fusion (#678) * Support LODR in RNNLM shallow fusion * fix style * fix code style * update workflow and CI * update results * propagate changes to stateless3 * add decoding results for stateless3+giga * fix CI --- ...-lstm-transducer-stateless2-2022-09-03.yml | 67 ++++- ...-lstm-transducer-stateless2-2022-09-03.yml | 15 +- egs/librispeech/ASR/RESULTS.md | 87 ++++++ .../ASR/lstm_transducer_stateless2/decode.py | 51 +++- .../beam_search.py | 264 ++++++++++++++++++ .../pruned_transducer_stateless3/decode.py | 181 +++++++++++- 6 files changed, 646 insertions(+), 19 deletions(-) diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 6ce92d022..ac5b15979 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -16,6 +16,7 @@ log "Downloading pre-trained model from $repo_url" git lfs install git clone $repo_url repo=$(basename $repo_url) +abs_repo=$(realpath $repo) log "Display test files" tree $repo/ @@ -178,21 +179,27 @@ echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm log "Download pre-trained RNN-LM model from ${lm_repo_url}" - git clone $lm_repo_url + GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url lm_repo=$(basename $lm_repo_url) pushd $lm_repo git lfs pull --include "exp/pretrained.pt" - cd exp - ln -s pretrained.pt epoch-88.pt + mv exp/pretrained.pt exp/epoch-88.pt popd + mkdir -p lstm_transducer_stateless2/exp + ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + ./lstm_transducer_stateless2/decode.py \ --use-averaged-model 0 \ - --epoch 99 \ + --epoch 999 \ --avg 1 \ - --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method modified_beam_search_rnnlm_shallow_fusion \ --beam 4 \ @@ -204,6 +211,52 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then --rnn-lm-tie-weights 1 fi +if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then + bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram + log "Download bi-gram LM from ${bigram_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url + bigramlm_repo=$(basename $bigram_repo_url) + pushd $bigramlm_repo + git lfs pull --include "2gram.fst.txt" + cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/. 
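+  # decode.py resolves the bi-gram as ${tokens_ngram}gram.fst.txt under its
+  # --lang-dir, so 2gram.fst.txt has to sit next to the BPE lang files.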
+ popd + + lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + log "Download pre-trained RNN-LM model from ${lm_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url + lm_repo=$(basename $lm_repo_url) + pushd $lm_repo + git lfs pull --include "exp/pretrained.pt" + mv exp/pretrained.pt exp/epoch-88.pt + popd + + mkdir -p lstm_transducer_stateless2/exp + ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + + ./lstm_transducer_stateless2/decode.py \ + --use-averaged-model 0 \ + --epoch 999 \ + --avg 1 \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_LODR \ + --beam 4 \ + --rnn-lm-scale 0.3 \ + --rnn-lm-exp-dir $lm_repo/exp \ + --rnn-lm-epoch 88 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 +fi + if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then mkdir -p lstm_transducer_stateless2/exp ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index a90841fb6..5f0acf9b8 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -18,7 +18,7 @@ on: jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: - if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -139,9 +139,20 @@ jobs: find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + - name: Display decoding results for lstm_transducer_stateless2 + if: github.event.label.name == 'LODR' + shell: bash + run: | + cd egs/librispeech/ASR + tree lstm_transducer_stateless2/exp + cd lstm_transducer_stateless2/exp + echo "===modified_beam_search_rnnlm_LODR===" + find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' + if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03 path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md 
index efd60ba81..c2ea3d050 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -318,6 +318,7 @@ The WERs are:
 | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
 | modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
 | modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
+| modified_beam_search + RNNLM shallow fusion + LODR | 2.28 | 5.94 | --iter 468000 --avg 16 |
 | fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
 | greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
 | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
@@ -393,6 +394,32 @@ for iter in 472000; do
   done
 done
 
+You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in <https://arxiv.org/abs/2203.16776>.
+It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
+generated by `generate-lm.sh`, or you may download it from <https://huggingface.co/marcoyang/librispeech_bigram>.
+
+The decoding command is as follows:
+
+```bash
+for iter in 472000; do
+  for avg in 8 10 12 14 16 18; do
+    ./lstm_transducer_stateless2/decode.py \
+      --iter $iter \
+      --avg $avg \
+      --exp-dir ./lstm_transducer_stateless2/exp \
+      --max-duration 600 \
+      --decoding-method modified_beam_search_rnnlm_LODR \
+      --beam 4 \
+      --rnn-lm-scale 0.4 \
+      --rnn-lm-exp-dir /path/to/RNNLM \
+      --rnn-lm-epoch 99 \
+      --rnn-lm-avg 1 \
+      --rnn-lm-num-layers 3 \
+      --rnn-lm-tie-weights 1 \
+      --tokens-ngram 2 \
+      --ngram-lm-scale -0.16
+  done
+done
+```
+
 Pretrained models, training logs, decoding logs, and decoding results
 are available at
 
@@ -1912,6 +1939,8 @@ subset so that the gigaspeech dataloader never exhausts.
 |-------------------------------------|------------|------------|---------------------------------------------|
 | greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
 | modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
+| modified beam search + rnnlm shallow fusion | 1.94 | 4.20 | --iter 1224000 --avg 14 --max-duration 600 |
+| modified beam search + LODR | 1.83 | 4.03 | --iter 1224000 --avg 14 --max-duration 600 |
 | fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
 
 The training commands are:
@@ -1957,6 +1986,64 @@ for iter in 1224000; do
   done
 done
 ```
+
+You may also decode using shallow fusion with an external RNNLM. To do so, you need to
+download a well-trained RNNLM from <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>.
+
+```bash
+rnn_lm_scale=0.3
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_shallow_fusion ; do
+      ./pruned_transducer_stateless3/decode.py \
+          --iter $iter \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+          --max-duration 600 \
+          --decoding-method $method \
+          --max-sym-per-frame 1 \
+          --beam 4 \
+          --max-contexts 32 \
+          --rnn-lm-scale $rnn_lm_scale \
+          --rnn-lm-exp-dir /path/to/RNNLM \
+          --rnn-lm-epoch 99 \
+          --rnn-lm-avg 1 \
+          --rnn-lm-num-layers 3 \
+          --rnn-lm-tie-weights 1
+    done
+  done
+done
+```
+
+If you want to try LODR decoding, the commands below assume you have a bi-gram LM trained on LibriSpeech text. You can also download the bi-gram LM from <https://huggingface.co/marcoyang/librispeech_bigram> and put it under the directory `data/lang_bpe_500`.
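+To see what the two LM scales do, here is a minimal sketch of how LODR combines
+the scores for each candidate non-blank token (illustrative only, with made-up
+names; the actual batched implementation is `modified_beam_search_rnnlm_LODR`
+in `pruned_transducer_stateless2/beam_search.py`):
+
+```python
+def lodr_token_score(
+    am_logp: float,  # transducer log-prob of the candidate token
+    rnnlm_logp: float,  # external RNNLM log-prob of the same token
+    bigram_logp: float,  # bi-gram log-prob, the internal-LM estimate
+    rnn_lm_scale: float = 0.4,
+    ngram_lm_scale: float = -0.14,  # negative, so the bi-gram is subtracted
+) -> float:
+    # total score = AM + scaled external LM - scaled internal-LM estimate
+    return am_logp + rnn_lm_scale * rnnlm_logp + ngram_lm_scale * bigram_logp
+```
+
+The decoding command is as follows: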
+
+```bash
+rnn_lm_scale=0.4
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_LODR ; do
+      ./pruned_transducer_stateless3/decode.py \
+          --iter $iter \
+          --avg $avg \
+          --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+          --max-duration 600 \
+          --decoding-method $method \
+          --max-sym-per-frame 1 \
+          --beam 4 \
+          --max-contexts 32 \
+          --rnn-lm-scale $rnn_lm_scale \
+          --rnn-lm-exp-dir /path/to/RNNLM \
+          --rnn-lm-epoch 99 \
+          --rnn-lm-avg 1 \
+          --rnn-lm-num-layers 3 \
+          --rnn-lm-tie-weights 1 \
+          --tokens-ngram 2 \
+          --ngram-lm-scale -0.14
+    done
+  done
+done
+```
 
 The pretrained models, training logs, decoding logs, and decoding results
 can be found at
 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index 69f695fef..fa5bf1825 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -107,8 +107,25 @@ Usage:
     --rnn-lm-avg 1 \
     --rnn-lm-num-layers 3 \
     --rnn-lm-tie-weights 1
-"""
+
+(9) modified beam search with RNNLM shallow fusion + LODR
+./lstm_transducer_stateless2/decode.py \
+    --epoch 35 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./lstm_transducer_stateless2/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+"""
 
 import argparse
 import logging
@@ -132,6 +149,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
     modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
     modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
@@ -235,7 +253,8 @@ def get_parser():
         - fast_beam_search_nbest_oracle
         - fast_beam_search_nbest_LG
         - modified_beam_search_ngram_rescoring
-        - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion
+        - modified_beam_search_rnnlm_shallow_fusion
+        - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -394,7 +413,8 @@ def get_parser():
         type=int,
         default=3,
         help="""Token Ngram used for rescoring.
-        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+        Used only when the decoding method is
+        modified_beam_search_ngram_rescoring""",
     )
 
     parser.add_argument(
@@ -402,7 +422,8 @@ def get_parser():
         type=int,
         default=500,
         help="""ID of the backoff symbol.
- Used only when the decoding method is modified_beam_search_ngram_rescoring""", + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", ) add_model_arguments(parser) @@ -572,6 +593,20 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_LODR": + hyp_tokens = modified_beam_search_rnnlm_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -760,6 +795,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_rnnlm_LODR", "modified_beam_search_ngram_rescoring", "modified_beam_search_rnnlm_shallow_fusion", ) @@ -788,6 +824,9 @@ def main(): if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "LODR" in params.decoding_method: + params.suffix += "-LODR" + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -901,7 +940,7 @@ def main(): model.eval() # only load N-gram LM when needed - if "ngram" in params.decoding_method: + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: lm_filename = f"{params.tokens_ngram}gram.fst.txt" logging.info(f"lm filename: {lm_filename}") ngram_lm = NgramLm( @@ -910,6 +949,7 @@ def main(): is_binary=False, ) logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale else: ngram_lm = None ngram_lm_scale = None @@ -933,7 +973,6 @@ def main(): ) rnn_lm_model.to(device) rnn_lm_model.eval() - else: rnn_lm_model = None rnn_lm_scale = 0.0 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5e9428b60..59c8ed5b5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2083,3 +2083,267 @@ def modified_beam_search_rnnlm_shallow_fusion( tokens=ans, timestamps=ans_timestamps, ) + + +def modified_beam_search_rnnlm_LODR( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + LODR_lm: NgramLm, + LODR_lm_scale: float, + rnnlm: RnnLmModel, + rnnlm_scale: float, + beam: int = 4, +) -> List[List[int]]: + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LODR_lm: + A low order n-gram LM + LODR_lm_scale: + The scale of the LODR_lm + rnnlm (RnnLmModel): + RNNLM, the external language model + rnnlm_scale (float): + scale of RNNLM in shallow fusion + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
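+
+    Note:
+      For each accepted non-blank token, the hypothesis score is updated as
+      logp += rnnlm_scale * rnnlm_logp + LODR_lm_scale * ngram_logp,
+      where LODR_lm_scale is typically negative (e.g. -0.16), so the bi-gram
+      score is subtracted from the shallow-fusion score.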
+
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+    assert rnnlm is not None
+    lm_scale = rnnlm_scale
+    vocab_size = rnnlm.vocab_size
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    sos_id = sp.piece_to_id("<sos/eos>")
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    # get initial lm score and lm state by scoring the "sos" token
+    sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device)
+    init_score, init_states = rnnlm.score_token(sos_token)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                state=init_states,  # state of the RNNLM
+                lm_score=init_score.reshape(-1),
+                state_cost=NgramLmStateCost(
+                    LODR_lm
+                ),  # state of the source domain ngram
+            )
+        )
+
+    rnnlm.clean_cache()
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]  # get batch
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            project_input=False,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+        """
+        For all hyps with a non-blank new token, score this token.
+        It is a little confusing here because this for-loop
+        looks very similar to the one below. Here, we go through all
+        top-k tokens and only add the non-blank ones to the token_list.
+        The RNNLM will score those tokens given the LM states. Note that
+        the variable `scores` is the LM score after seeing the new
+        non-blank token.
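+        Blank (and unk) continuations are handled in the second loop:
+        they inherit the parent hypothesis' RNNLM state and n-gram
+        state_cost unchanged.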
+        """
+        token_list = []
+        hs = []
+        cs = []
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+                    assert new_token != 0, new_token
+                    token_list.append([new_token])
+                    # store the LSTM states
+                    hs.append(hyp.state[0])
+                    cs.append(hyp.state[1])
+
+        # forward RNNLM to get new states and scores
+        if len(token_list) != 0:
+            tokens_to_score = (
+                torch.tensor(token_list)
+                .to(torch.int64)
+                .to(device)
+                .reshape(-1, 1)
+            )
+
+            hs = torch.cat(hs, dim=1).to(device)
+            cs = torch.cat(cs, dim=1).to(device)
+            scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs))
+
+        count = 0  # index, used to locate score and lm states
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                ys = hyp.ys[:]
+
+                # current score of hyp
+                lm_score = hyp.lm_score
+                state = hyp.state
+
+                hyp_log_prob = topk_log_probs[k]  # get score of current hyp
+                new_token = topk_token_indexes[k]
+                if new_token not in (blank_id, unk_id):
+                    ys.append(new_token)
+                    state_cost = hyp.state_cost.forward_one_step(new_token)
+
+                    # calculate the score of the latest token
+                    current_ngram_score = (
+                        state_cost.lm_score - hyp.state_cost.lm_score
+                    )
+
+                    assert current_ngram_score <= 0.0, (
+                        state_cost.lm_score,
+                        hyp.state_cost.lm_score,
+                    )
+                    # score = score + RNNLM_score - LODR_score
+                    # LODR_lm_scale is a negative number here
+                    hyp_log_prob += (
+                        lm_score[new_token] * lm_scale
+                        + LODR_lm_scale * current_ngram_score
+                    )  # add the lm score
+
+                    lm_score = scores[count]
+                    state = (
+                        lm_states[0][:, count, :].unsqueeze(1),
+                        lm_states[1][:, count, :].unsqueeze(1),
+                    )
+                    count += 1
+                else:
+                    state_cost = hyp.state_cost
+
+                new_hyp = Hypothesis(
+                    ys=ys,
+                    log_prob=hyp_log_prob,
+                    state=state,
+                    lm_score=lm_score,
+                    state_cost=state_cost,
+                )
+                B[i].add(new_hyp)
+
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])
+
+    return ans
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
index 03137501f..e00aab34a 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
+#                                            Xiaoyu Yang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -90,8 +91,40 @@ Usage:
     --beam 20.0 \
     --max-contexts 8 \
     --max-states 64
-"""
+
+(8) modified beam search (with RNNLM shallow fusion)
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir /path/to/RNNLM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+(9) modified beam search with RNNLM shallow fusion + LODR
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+"""
 
 import argparse
 import logging
@@ -116,10 +149,14 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
+    modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import NgramLm
 from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.rnn_lm.model import RnnLmModel
@@ -202,6 +239,9 @@ def get_parser():
         - fast_beam_search_nbest
         - fast_beam_search_nbest_oracle
         - fast_beam_search_nbest_LG
+        - modified_beam_search_ngram_rescoring
+        - modified_beam_search_rnnlm_shallow_fusion
+        - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
         """,
@@ -263,6 +303,7 @@ def get_parser():
         default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
+
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -341,6 +382,15 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when --decoding-method is
+        modified_beam_search_rnnlm_shallow_fusion or modified_beam_search_rnnlm_LODR.
+        It specifies the scale of the RNN LM in shallow fusion.
+        """,
+    )
+
     parser.add_argument(
         "--rnn-lm-exp-dir",
         type=str,
@@ -397,6 +447,24 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Token Ngram used for rescoring.
+        Used only when the decoding method is
+        modified_beam_search_ngram_rescoring""",
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+        Used only when the decoding method is
+        modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -410,7 +478,10 @@ def decode_one_batch(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
-    rnn_lm_model: torch.nn.Module = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    rnn_lm_model: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -444,6 +515,14 @@ def decode_one_batch(
       fast_beam_search_nbest, fast_beam_search_nbest_oracle,
       or fast_beam_search_with_nbest_rescoring.
       It an FsaVec containing an acceptor.
+    rnn_lm_model:
+      An RNN LM that can be used for rescoring or shallow fusion.
+    rnnlm_scale:
+      The scale of the RNN LM.
+    ngram_lm:
+      An n-gram LM, used in LODR decoding.
+    ngram_lm_scale:
+      The scale of the n-gram language model.
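+      A negative value (e.g. -0.16) subtracts the n-gram score, as is done
+      in LODR decoding.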
Returns: Return the decoding result. See above description for the format of the returned dict. @@ -607,6 +686,43 @@ def decode_one_batch( nbest_scale=params.nbest_scale, temperature=params.temperature, ) + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + rnnlm=rnn_lm_model, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_LODR": + hyp_tokens = modified_beam_search_rnnlm_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + rnnlm=rnn_lm_model, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -693,7 +809,10 @@ def decode_dataset( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None, - rnn_lm_model: torch.nn.Module = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + rnn_lm_model: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -749,7 +868,10 @@ def decode_dataset( decoding_graph=decoding_graph, batch=batch, G=G, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -900,6 +1022,9 @@ def main(): "modified_beam_search", "fast_beam_search_with_nbest_rescoring", "fast_beam_search_with_nbest_rnn_rescoring", + "modified_beam_search_rnnlm_LODR", + "modified_beam_search_ngram_rescoring", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -930,6 +1055,13 @@ def main(): params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-temperature-{params.temperature}" + if "rnnlm" in params.decoding_method: + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "LODR" in params.decoding_method: + params.suffix += "-LODR" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1048,6 +1180,44 @@ def main(): word_table = None rnn_lm_model = None + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load rnnlm if used + if "rnnlm" in params.decoding_method: + rnn_lm_scale = params.rnn_lm_scale + + rnn_lm_model = RnnLmModel( + 
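+            # the hyper-parameters below must match the RNN LM checkpoint
+            # that is loaded from --rnn-lm-exp-dir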
vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + else: + rnn_lm_model = None + rnn_lm_scale = 0.0 + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -1074,7 +1244,10 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, G=G, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results(