From 2a52b8c125019feb305275b4e356ea5969a35046 Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 3 Nov 2022 11:10:21 +0800 Subject: [PATCH] update docs --- .../ASR/lstm_transducer_stateless2/decode.py | 35 +++++++++++-------- .../beam_search.py | 25 ++++++++++--- .../pruned_transducer_stateless5/decode.py | 8 ++--- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 20a5ebd8b..40a0d5bf7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -235,7 +235,7 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -329,7 +329,7 @@ def get_parser(): "--rnn-lm-scale", type=float, default=0.0, - help="""Used only when --method is modified_beam_search3. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -338,7 +338,7 @@ def get_parser(): "--rnn-lm-exp-dir", type=str, default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -347,7 +347,7 @@ def get_parser(): "--rnn-lm-epoch", type=int, default=7, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the checkpoint to use. """, ) @@ -356,7 +356,7 @@ def get_parser(): "--rnn-lm-avg", type=int, default=2, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the number of checkpoints to average. """, ) @@ -911,14 +911,20 @@ def main(): model.to(device) model.eval() - lm_filename = f"{params.tokens_ngram}gram.fst.txt" - logging.info(f"lm filename: {lm_filename}") - ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) - logging.info(f"num states: {ngram_lm.lm.num_states}") + # only load N-gram LM when needed + if "ngram" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + else: + ngram_lm = None + ngram_lm_scale = None + # only load rnnlm if used if "rnnlm" in params.decoding_method: rnn_lm_scale = params.rnn_lm_scale @@ -941,6 +947,7 @@ def main(): else: rnn_lm_model = None + rnn_lm_scale = 0.0 if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": @@ -987,7 +994,7 @@ def main(): word_table=word_table, decoding_graph=decoding_graph, ngram_lm=ngram_lm, - ngram_lm_scale=params.ngram_lm_scale, + ngram_lm_scale=ngram_lm_scale, rnnlm=rnn_lm_model, rnnlm_scale=rnn_lm_scale, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7c5a5ace4..480146a59 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -17,7 +17,7 @@ import warnings from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import k2 import sentencepiece as spm @@ -729,8 +729,15 @@ class Hypothesis: # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded - timestamp: List[int] + timestamp: List[int] = None + # the lm score for next token given the current ys + lm_score: Optional[torch.Tensor] = None + + # the RNNLM states (h and c in LSTM) + state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # N-gram LM state state_cost: Optional[NgramLmStateCost] = None @property @@ -1989,8 +1996,15 @@ def modified_beam_search_rnnlm_shallow_fusion( ragged_log_probs = k2.RaggedTensor( shape=log_probs_shape, value=log_probs ) - - # for all hyps with a non-blank new token, score it + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + The RNNLM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ token_list = [] hs = [] cs = [] @@ -2007,11 +2021,12 @@ def modified_beam_search_rnnlm_shallow_fusion( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token token_list.append([new_token]) + # store the LSTM states hs.append(hyp.state[0]) cs.append(hyp.state[1]) + # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 8ba36e582..2711c4cc9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -228,7 +228,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - - modified-beam-search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -354,7 +354,7 @@ def get_parser(): "--rnn-lm-exp-dir", type=str, default="rnn_lm/exp", - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the path to RNN LM exp dir. """, ) @@ -363,7 +363,7 @@ def get_parser(): "--rnn-lm-epoch", type=int, default=7, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the checkpoint to use. """, ) @@ -372,7 +372,7 @@ def get_parser(): "--rnn-lm-avg", type=int, default=2, - help="""Used only when --method is rnn-lm. + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. It specifies the number of checkpoints to average. """, )