From 45c13e90e42d0f6ff190d69acb18f4e868bfa954 Mon Sep 17 00:00:00 2001
From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com>
Date: Mon, 24 Apr 2023 15:00:02 +0800
Subject: [PATCH] RNNLM rescore + Low-order density ratio (#1017)

* add rnnlm rescore + LODR

* add LODR in decode.py

* update RESULTS
---
 egs/librispeech/ASR/RESULTS.md |  38 ++-
 .../beam_search.py             | 218 +++++++++++++++++-
 .../decode.py                  |  99 +++++++-
 3 files changed, 345 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index 5a956fc9c..ef817d5dd 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -215,11 +215,12 @@ done

We also support decoding with neural network LMs. After combining with language models, the WERs are
| decoding method | chunk size | test-clean | test-other | comment | decoding mode |
|----------------------|------------|------------|------------|---------------------|----------------------|
-| modified beam search | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
-| modified beam search + RNNLM shallow fusion | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
-| modified beam search + RNNLM nbest rescore | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
+| `modified_beam_search` | 320ms | 3.11 | 7.93 | --epoch 30 --avg 9 | simulated streaming |
+| `modified_beam_search_lm_shallow_fusion` | 320ms | 2.58 | 6.65 | --epoch 30 --avg 9 | simulated streaming |
+| `modified_beam_search_lm_rescore` | 320ms | 2.59 | 6.86 | --epoch 30 --avg 9 | simulated streaming |
+| `modified_beam_search_lm_rescore_LODR` | 320ms | 2.52 | 6.73 | --epoch 30 --avg 9 | simulated streaming |

-Please use the following command for RNNLM shallow fusion:
+Please use the following command for `modified_beam_search_lm_shallow_fusion`:
```bash
for lm_scale in $(seq 0.15 0.01 0.38); do
  for beam_size in 4 8 12; do
@@ -246,7 +247,7 @@
done
```

-Please use the following command for RNNLM rescore:
+Please use the following command for `modified_beam_search_lm_rescore`:
```bash
./pruned_transducer_stateless7_streaming/decode.py \
  --epoch 30 \
  --avg 9 \
  --use-averaged-model True \
  --beam-size 8 \
  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
  --max-duration 600 \
  --decode-chunk-len 32 \
  --decoding-method modified_beam_search_lm_rescore \
  --use-shallow-fusion 0 \
  --lm-type rnn \
  --lm-exp-dir rnn_lm/exp \
  --lm-epoch 99 \
  --lm-avg 1 \
  --rnn-lm-embedding-dim 2048 \
  --rnn-lm-hidden-dim 2048 \
  --rnn-lm-num-layers 3 \
  --lm-vocab-size 500
```

@@ -268,7 +269,32 @@
-A well-trained RNNLM can be found here: .
+Please use the following command for `modified_beam_search_lm_rescore_LODR`:
+```bash
+./pruned_transducer_stateless7_streaming/decode.py \
+  --epoch 30 \
+  --avg 9 \
+  --use-averaged-model True \
+  --beam-size 8 \
+  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+  --max-duration 600 \
+  --decode-chunk-len 32 \
+  --decoding-method modified_beam_search_lm_rescore_LODR \
+  --use-shallow-fusion 0 \
+  --lm-type rnn \
+  --lm-exp-dir rnn_lm/exp \
+  --lm-epoch 99 \
+  --lm-avg 1 \
+  --rnn-lm-embedding-dim 2048 \
+  --rnn-lm-hidden-dim 2048 \
+  --rnn-lm-num-layers 3 \
+  --lm-vocab-size 500 \
+  --tokens-ngram 2 \
+  --backoff-id 500
+```
+
+A well-trained RNNLM can be found here: . The bi-gram used in LODR decoding
+can be found here: .
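+
+In `modified_beam_search_lm_rescore_LODR`, each hypothesis in the n-best list
+produced by beam search is rescored with
+`tot_score = am_score / lm_scale + rnnlm_score - lodr_scale * bigram_score`,
+where the token-level bi-gram acts as a low-order approximation of the internal
+LM that is subtracted out (the "low-order density ratio"). Both `lm_scale` and
+`lodr_scale` are swept over a grid during decoding; see
+`modified_beam_search_lm_rescore_LODR` in `beam_search.py` below.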

#### Smaller model

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
index c44a2ad3e..e45f2c652 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -1244,7 +1244,7 @@ def modified_beam_search_lm_rescore(

    # get the best hyp with different lm_scale
    for lm_scale in lm_scale_list:
-        key = f"nnlm_scale_{lm_scale}"
+        key = f"nnlm_scale_{lm_scale:.2f}"
        tot_scores = am_scores.values + lm_scores * lm_scale
        ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
        max_indexes = ragged_tot_scores.argmax().tolist()
@@ -1257,6 +1257,222 @@ def modified_beam_search_lm_rescore(
    return ans

+
+def modified_beam_search_lm_rescore_LODR(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    LM: LmScorer,
+    LODR_lm: NgramLm,
+    sp: spm.SentencePieceProcessor,
+    lm_scale_list: List[float],
+    beam: int = 4,
+    temperature: float = 1.0,
+    return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+    Rescore the final results with an RNNLM and a low-order n-gram LM (LODR)
+    and return the hypothesis with the highest score.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing number of valid frames in
+        encoder_out before padding.
+      beam:
+        Number of active paths during the beam search.
+      temperature:
+        Softmax temperature.
+      LM:
+        A neural network language model.
+      LODR_lm:
+        A low-order token-level n-gram LM whose score is subtracted during
+        rescoring; decode.py passes in a bi-gram kenlm model loaded from an
+        ARPA file.
+      sp:
+        The SentencePiece processor; used to map token IDs back to pieces
+        when scoring hypotheses with the n-gram LM.
+      lm_scale_list:
+        A list of RNNLM scales to try during rescoring.
+      return_timestamps:
+        Whether to return timestamps.
+    Returns:
+      If return_timestamps is False, return the decoded result.
+      Else, return a DecodingResults object containing
+      decoded result and corresponding timestamps.
+    """
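+    # Overview: run modified beam search with the transducer model to collect
+    # an n-best list for each utterance, then rescore every hypothesis with
+    # the RNNLM score minus a scaled low-order n-gram (bi-gram) score, and
+    # keep the best hypothesis for each (lm_scale, lodr_scale) pair.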
+    assert encoder_out.ndim == 3, encoder_out.shape
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )
+
+    blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
+    context_size = model.decoder.context_size
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+                timestamp=[],
+            )
+        )
+
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+    offset = 0
+    finalized_B = []
+    for (t, batch_size) in enumerate(batch_size_list):
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]
+
+        hyps_shape = get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        decoder_out = model.joiner.decoder_proj(decoder_out)
+        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            project_input=False,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                new_timestamp = hyp.timestamp[:]
+                if new_token not in (blank_id, unk_id):
+                    new_ys.append(new_token)
+                    new_timestamp.append(t)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(
+                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+                )
+                B[i].add(new_hyp)
+
+    B = B + finalized_B
+
+    # get the am_scores for n-best list
+    hyps_shape = get_hyps_shape(B)
+    am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b])
+    am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device)
+
+    # now LM rescore
+    # prepare input data to LM
+    candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b]
+    possible_seqs = k2.RaggedTensor(candidate_seqs)
+    row_splits = possible_seqs.shape.row_splits(1)
+    sentence_token_lengths = row_splits[1:] - row_splits[:-1]
+    possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1)
+    possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1)
+    sentence_token_lengths += 1  # +1 for the added SOS/EOS token
+
+    x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id)
+    y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id)
+    x = x.to(device).to(torch.int64)
+    y = y.to(device).to(torch.int64)
+    sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64)
+
+    # the LM returns per-token negative log-likelihoods of shape (N, L);
+    # summing over the sequence and negating yields the sentence log-likelihood
+    lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths)
+    assert lm_scores.ndim == 2
+    lm_scores = -1 * lm_scores.sum(dim=1)
+
+    # now LODR scores
+    import math
+
+    LODR_scores = []
+    for seq in candidate_seqs:
+        tokens = " ".join(sp.id_to_piece(seq))
+        LODR_scores.append(LODR_lm.score(tokens))
+    LODR_scores = torch.tensor(LODR_scores).to(device) * math.log(
+        10
+    )  # ARPA LM scores are base-10 logarithms; convert them to natural log
+    assert lm_scores.shape == LODR_scores.shape
+
+    ans = {}
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+
+    LODR_scale_list = [0.05 * i for i in range(1, 20)]
+    # get the best hyp with different lm_scale and lodr_scale
+    for lm_scale in lm_scale_list:
+        for lodr_scale in LODR_scale_list:
+            key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}"
+            # combined score: AM (scaled by 1 / lm_scale) + RNNLM
+            # log-likelihood - scaled bi-gram (LODR) score
+            tot_scores = (
+                am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale
+            )
+            ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores)
+            max_indexes = ragged_tot_scores.argmax().tolist()
+            unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes]
+            hyps = []
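+            # `pack_padded_sequence` sorted the utterances by length; map the
+            # best hypotheses back to the original order of the batch.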
+            for idx in unsorted_indices:
+                hyps.append(unsorted_hyps[idx])
+
+            ans[key] = hyps
+    return ans
+
+
def _deprecated_modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
index 8aa0d8689..3444f8193 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -123,10 +123,13 @@ from beam_search import (
    greedy_search_batch,
    modified_beam_search,
    modified_beam_search_lm_rescore,
+    modified_beam_search_lm_rescore_LODR,
    modified_beam_search_lm_shallow_fusion,
+    modified_beam_search_LODR,
)
from train import add_model_arguments, get_params, get_transducer_model

+from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -134,7 +137,6 @@ from icefall.checkpoint import (
    load_checkpoint,
)
from icefall.lexicon import Lexicon
-from icefall.lm_wrapper import LmScorer
from icefall.utils import (
    AttributeDict,
    setup_logger,
@@ -336,6 +338,21 @@ def get_parser():
        """,
    )

+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=2,
+        help="""The order of the n-gram LM used for LODR.
+        """,
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the n-gram LM",
+    )
+
    add_model_arguments(parser)

    return parser
@@ -349,6 +366,8 @@ def decode_one_batch(
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
+    ngram_lm=None,
+    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -483,6 +502,18 @@ def decode_one_batch(
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_LODR":
+        hyp_tokens = modified_beam_search_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LODR_lm=ngram_lm,
+            LODR_lm_scale=ngram_lm_scale,
+            LM=LM,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_lm_rescore":
        lm_scale_list = [0.01 * i for i in range(10, 50)]
        ans_dict = modified_beam_search_lm_rescore(
@@ -493,6 +524,18 @@ def decode_one_batch(
            LM=LM,
            lm_scale_list=lm_scale_list,
        )
+    elif params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+        lm_scale_list = [0.02 * i for i in range(2, 30)]
+        ans_dict = modified_beam_search_lm_rescore_LODR(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+            LM=LM,
+            LODR_lm=ngram_lm,
+            sp=sp,
+            lm_scale_list=lm_scale_list,
+        )
    else:
        batch_size = encoder_out.size(0)

@@ -531,7 +574,10 @@ def decode_one_batch(
            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: hyps}
-    elif params.decoding_method == "modified_beam_search_lm_rescore":
+    elif params.decoding_method in (
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
+    ):
        ans = dict()
        assert ans_dict is not None
        for key, hyps in ans_dict.items():
@@ -550,6 +596,8 @@ def decode_dataset(
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    LM: Optional[LmScorer] = None,
+    ngram_lm=None,
+    ngram_lm_scale: float = 0.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.
@@ -568,6 +616,8 @@ def decode_dataset(
        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
        only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
        fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
+      ngram_lm:
+        An n-gram LM to be used for LODR.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
@@ -600,6 +650,8 @@ def decode_dataset(
            word_table=word_table,
            batch=batch,
            LM=LM,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():
@@ -677,8 +729,10 @@ def main():
        "fast_beam_search_nbest_LG",
        "fast_beam_search_nbest_oracle",
        "modified_beam_search",
+        "modified_beam_search_LODR",
        "modified_beam_search_lm_shallow_fusion",
        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
    )

    params.res_dir = params.exp_dir / params.decoding_method
@@ -822,7 +876,12 @@ def main():
    model.eval()

    # only load the neural network LM if required
-    if params.use_shallow_fusion or "lm" in params.decoding_method:
+    if params.use_shallow_fusion or params.decoding_method in (
+        "modified_beam_search_lm_rescore",
+        "modified_beam_search_lm_rescore_LODR",
+        "modified_beam_search_lm_shallow_fusion",
+        "modified_beam_search_LODR",
+    ):
        LM = LmScorer(
            lm_type=params.lm_type,
            params=params,
@@ -834,6 +893,35 @@ def main():
    else:
        LM = None

+    # only load the n-gram LM when needed
+    if params.decoding_method == "modified_beam_search_lm_rescore_LODR":
+        try:
+            import kenlm
+        except ImportError:
+            print("Please install kenlm first. You can use")
+            print("  pip install https://github.com/kpu/kenlm/archive/master.zip")
+            print("to install it")
+            import sys
+
+            sys.exit(-1)
+        ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa")
+        logging.info(f"lm filename: {ngram_file_name}")
+        ngram_lm = kenlm.Model(ngram_file_name)
+
+    elif params.decoding_method == "modified_beam_search_LODR":
+        lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+        logging.info(f"Loading token level lm: {lm_filename}")
+        ngram_lm = NgramLm(
+            str(params.lang_dir / lm_filename),
+            backoff_id=params.backoff_id,
+            is_binary=False,
+        )
+        logging.info(f"num states: {ngram_lm.lm.num_states}")
+        ngram_lm_scale = params.ngram_lm_scale
+    else:
+        ngram_lm = None
+        ngram_lm_scale = None
+
    if "fast_beam_search" in params.decoding_method:
        if params.decoding_method == "fast_beam_search_nbest_LG":
            lexicon = Lexicon(params.lang_dir)
@@ -866,8 +954,10 @@ def main():
    test_sets = ["test-clean", "test-other"]
    test_dl = [test_clean_dl, test_other_dl]

+    import time

    for test_set, test_dl in zip(test_sets, test_dl):
+        start = time.time()
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
@@ -876,7 +966,10 @@ def main():
            word_table=word_table,
            decoding_graph=decoding_graph,
            LM=LM,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
        )
+        logging.info(f"Elapsed time for {test_set}: {time.time() - start}")

        save_results(
            params=params,
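
For reference, the score combination implemented by `modified_beam_search_lm_rescore_LODR` can be illustrated with a minimal standalone sketch. This is not part of the patch; the hypothesis texts and scores are invented for illustration. In the real code the AM scores come from beam search, the RNNLM scores from `LmScorer`, and the bi-gram scores from kenlm, which reports base-10 logs (hence the `log(10)` conversion):

```python
import math

# Toy n-best list for one utterance:
# (text, am_score, rnnlm_logprob, bigram_log10)
#   am_score      - transducer log-prob accumulated during beam search
#   rnnlm_logprob - sentence log-likelihood under the RNNLM (natural log)
#   bigram_log10  - token-level bi-gram score from kenlm (base-10 log)
nbest = [
    ("IT WAS A BRIGHT COLD DAY", -12.3, -18.4, -9.1),
    ("IT WAS THE BRIGHT COLD DAY", -12.1, -21.0, -8.7),
    ("IT WAS A BRIGHT GOLD DAY", -12.8, -23.5, -9.4),
]

lm_scale = 0.4     # RNNLM scale; the patch sweeps [0.02 * i for i in range(2, 30)]
lodr_scale = 0.15  # bi-gram scale; the patch sweeps [0.05 * i for i in range(1, 20)]


def lodr_score(am: float, rnnlm: float, bigram_log10: float) -> float:
    # Same combination as in modified_beam_search_lm_rescore_LODR:
    #   am / lm_scale + rnnlm - lodr_scale * bigram
    # Convert the kenlm base-10 log into a natural log first.
    bigram = bigram_log10 * math.log(10)
    return am / lm_scale + rnnlm - lodr_scale * bigram


best = max(nbest, key=lambda h: lodr_score(h[1], h[2], h[3]))
print(best[0])  # hypothesis with the highest combined score
```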