diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 01cc566e8..d569b0752 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -23,7 +23,6 @@ import sentencepiece as spm import torch from model import Transducer -from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding from icefall.rnn_lm.model import RnnLmModel from icefall.utils import add_eos, add_sos, get_texts @@ -658,8 +657,6 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor - state_cost: Optional[NgramLmStateCost] = None - state: Optional = None lm_score: Optional=None @property diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 632932214..59c646717 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -19,36 +19,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ @@ -56,10 +56,10 @@ Usage: --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 30 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./pruned_transducer_stateless3/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -69,10 +69,10 @@ Usage: --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -82,10 +82,10 @@ Usage: --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless5/decode.py \ - --epoch 28 \ +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless5/exp \ + --exp-dir ./lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ @@ -115,6 +115,7 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_rnnlm_shallow_fusion, ) from train import add_model_arguments, get_params, get_transducer_model @@ -125,6 +126,7 @@ from icefall.checkpoint import ( load_checkpoint, ) from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, setup_logger, @@ -183,7 +185,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless5/exp", + default="lstm_transducer_stateless2/exp", help="The experiment dir", ) @@ -213,6 +215,7 @@ def get_parser(): - fast_beam_search_nbest - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG + - modified-beam-search3 # for rnn lm shallow fusion If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -240,16 +243,6 @@ def get_parser(): """, ) - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - parser.add_argument( "--max-contexts", type=int, @@ -275,6 +268,7 @@ def get_parser(): help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", ) + parser.add_argument( "--max-sym-per-frame", type=int, @@ -302,28 +296,69 @@ def get_parser(): ) parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. + "--rnn-lm-scale", + type=float, + default=0.0, + help="""Used only when --method is modified_beam_search3. + It specifies the path to RNN LM exp dir. """, ) parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, ) parser.add_argument( - "--left-context", + "--rnn-lm-epoch", type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, ) + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) add_model_arguments(parser) return parser @@ -336,6 +371,8 @@ def decode_one_batch( batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -361,7 +398,7 @@ def decode_one_batch( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -474,12 +511,21 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) for i in range(batch_size): # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] # fmt: on if params.decoding_method == "greedy_search": hyp = greedy_search( @@ -523,7 +569,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -538,7 +586,7 @@ def decode_dataset( word_table: The word symbol table. decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used only when --decoding_method is fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: @@ -564,6 +612,7 @@ def decode_dataset( for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + logging.info(f"Decoding {batch_idx}-th batch") hyps_dict = decode_one_batch( params=params, @@ -572,6 +621,8 @@ def decode_dataset( decoding_graph=decoding_graph, word_table=word_table, batch=batch, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -597,7 +648,7 @@ def decode_dataset( def save_results( params: AttributeDict, test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): test_set_wers = dict() for key, results in results_dict.items(): @@ -657,6 +708,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_sf_rnnlm", ) params.res_dir = params.exp_dir / params.decoding_method @@ -665,10 +717,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -686,6 +734,8 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -706,11 +756,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") @@ -796,6 +841,25 @@ def main(): model.to(device) model.eval() + rnn_lm_model = None + rnn_lm_scale = params.rnn_lm_scale + if params.decoding_method == "modified_beam_search3": + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + if "fast_beam_search" in params.decoding_method: if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) @@ -839,6 +903,8 @@ def main(): sp=sp, word_table=word_table, decoding_graph=decoding_graph, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, ) save_results(