From 0a46a39e24a687487eeab3396d35fd395b156a0c Mon Sep 17 00:00:00 2001 From: marcoyang Date: Wed, 2 Nov 2022 17:25:31 +0800 Subject: [PATCH] update decoding commands --- .../ASR/lstm_transducer_stateless2/decode.py | 33 +++--- .../beam_search.py | 105 +----------------- .../pruned_transducer_stateless5/decode.py | 95 +++++++++++----- 3 files changed, 88 insertions(+), 145 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 1d46c0177..c43328e08 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -91,6 +91,21 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search (with RNNLM shallow fusion) +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --beam-size 4 \ + --rnn-lm-scale 0.3 \ + --rnn-lm-exp-dir /path/to/RNNLM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 """ @@ -121,7 +136,6 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall import NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -389,8 +403,6 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: @@ -526,11 +538,12 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search_sf_rnnlm": - hyp_tokens = modified_beam_search_sf_rnnlm_batched( + elif 
params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + beam=params.beam_size, sp=sp, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, @@ -586,9 +599,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, - ngram_lm: Optional[NgramLm] = None, - ngram_lm_scale: float = 1.0, - rnnlm: Optional[NgramLm] = None, + rnnlm: Optional[RnnLmModel] = None, rnnlm_scale: float = 1.0, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -642,8 +653,6 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, rnnlm=rnnlm, rnnlm_scale=rnnlm_scale, ) @@ -731,7 +740,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_sf_rnnlm", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -942,8 +951,6 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, - ngram_lm=ngram_lm, - ngram_lm_scale=params.ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index d569b0752..e454bc1a6 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1,4 +1,5 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -656,6 +657,7 @@ class Hypothesis: # The log prob of ys. # It contains only one entry. 
log_prob: torch.Tensor + state: Optional=None lm_score: Optional=None @@ -1542,107 +1544,6 @@ def fast_beam_search_with_nbest_rnn_rescoring( ans[key] = hyps return ans - -def modified_beam_search_sf_rnnlm( - model: Transducer, - encoder_out: torch.Tensor, - sp, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, -): - encoder_out = model.joiner.encoder_proj(encoder_out) - lm_scale = rnnlm_scale - - assert rnnlm is not None - assert encoder_out.ndim == 2, encoder_out.shape - rnnlm.clean_cache() - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - eos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - B = HypothesisList() - B.add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - T = encoder_out.shape[0] - for t in range(T): - current_encoder_out = encoder_out[t : t + 1] - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyp in A] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - decoder_out = model.decoder(decoder_input, need_pad=False).squeeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - # decoder_out is of shape (num_hyps, joiner_dim) - current_encoder_out = current_encoder_out.repeat(len(A), 1) - # current_encoder_out is of shape (num_hyps, encoder_out_dim) - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk( - beam - ) # get topk tokens and scores - - with warnings.catch_warnings(): - 
warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[hyp_idx] # get hyp - new_ys = hyp.ys[:] - state = "ys=" + "+".join(list(map(str, new_ys))) - tokens = k2.RaggedTensor([new_ys[context_size:]]) - - lm_score = rnnlm.predict( - tokens, state, sos_id, eos_id, blank_id - ) # get rnnlm score - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] # get token - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - # state_cost = hyp.state_cost.forward_one_step(new_token) - hyp_log_prob += ( - lm_score[new_token] * lm_scale - ) # add the lm score - else: - new_ys = new_ys - new_log_prob = hyp_log_prob - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - return best_hyp.ys[context_size:] def modified_beam_search_rnnlm_shallow_fusion( model: Transducer, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 59c646717..8c69cfd6e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Zengwei Yao, +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,47 +20,43 @@ """ Usage: (1) greedy search -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method greedy_search - (2) beam search (not recommended) 
-./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 - (3) modified beam search -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 - (4) fast beam search (one best) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 - (5) fast beam search (nbest) -./lstm_transducer_stateless2/decode.py \ - --epoch 30 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -67,12 +64,11 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 - (6) fast beam search (nbest oracle WER) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir ./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -80,17 +76,34 @@ Usage: --max-states 64 \ --num-paths 200 \ --nbest-scale 0.5 - (7) fast beam search (with LG) -./lstm_transducer_stateless2/decode.py \ - --epoch 35 \ +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ --avg 15 \ - --exp-dir 
./lstm_transducer_stateless2/exp \ + --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ --max-contexts 8 \ --max-states 64 + +(8) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_shallow_fusion \ + --beam-size 4 \ + --max-contexts 4 \ + --rnn-lm-scale 0.4 \ + --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + + """ @@ -243,6 +256,16 @@ def get_parser(): """, ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -294,6 +317,15 @@ def get_parser(): Used only when the decoding method is fast_beam_search_nbest, fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. 
+ """, + ) parser.add_argument( "--rnn-lm-scale", @@ -517,6 +549,9 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -708,7 +743,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", - "modified_beam_search_sf_rnnlm", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -843,7 +878,7 @@ def main(): rnn_lm_model = None rnn_lm_scale = params.rnn_lm_scale - if params.decoding_method == "modified_beam_search3": + if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim,