From 4634911dc20b4d62bb18c497e2d86fa50ca560f3 Mon Sep 17 00:00:00 2001 From: Quandwang Date: Thu, 21 Jul 2022 17:49:51 +0800 Subject: [PATCH] add rnn lm decoding --- egs/librispeech/ASR/conformer_ctc2/decode.py | 133 ++++++++++++++++++- 1 file changed, 127 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 45fa97812..8a4cad1ad 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -46,13 +46,16 @@ from icefall.decode import ( one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, + rescore_with_rnn_lm, rescore_with_whole_lattice, ) from icefall.env import get_env_info from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, get_texts, + load_averaged_model, setup_logger, store_transcripts, str2bool, @@ -116,7 +119,9 @@ def get_parser(): is the decoding result. - (6) attention-decoder. Extract n paths from the LM rescored lattice, the path with the highest score is the decoding result. - - (7) nbest-oracle. Its WER is the lower bound of any n-best + - (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (8) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -148,7 +153,7 @@ def get_parser(): default=100, help="""Number of paths for n-best based decoding method. Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle """, ) @@ -159,7 +164,7 @@ def get_parser(): help="""The scale to be applied to `lattice.scores`. It's needed if you use any kinds of n-best based rescoring. Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle A smaller value results in more unique paths. """, ) @@ -182,11 +187,67 @@ def get_parser(): "--lm-dir", type=str, default="data/lm", - help="""The LM dir. + help="""The n-gram LM dir. It should contain either G_4_gram.pt or G_4_gram.fst.txt """, ) + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + return parser @@ -242,6 +303,7 @@ def ctc_greedy_search( def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py new_hyp: List[int] = [] cur = 0 while cur < len(hyp): @@ -256,6 +318,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: def decode_one_batch( params: AttributeDict, model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], @@ -288,6 +351,8 @@ def decode_one_batch( model: The neural model. + rnn_lm_model: + The neural model for RNN LM. HLG: The decoding graph. Used only when params.method is NOT ctc-decoding. H: @@ -436,6 +501,7 @@ def decode_one_batch( "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "rnn-lm", ] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] @@ -476,6 +542,26 @@ def decode_one_batch( eos_id=eos_id, nbest_scale=params.nbest_scale, ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -494,6 +580,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], @@ -511,6 +598,8 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. + rnn_lm_model: + The neural model for RNN LM. HLG: The decoding graph. Used only when params.method is NOT ctc-decoding. H: @@ -548,6 +637,7 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, + rnn_lm_model=rnn_lm_model, HLG=HLG, H=H, bpe_model=bpe_model, @@ -596,7 +686,7 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - if params.method == "attention-decoder": + if params.method in ("attention-decoder", "rnn-lm"): # Set it to False since there are too many logs. enable_log = False else: @@ -700,6 +790,7 @@ def main(): "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "rnn-lm", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -731,7 +822,11 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) @@ -836,6 +931,31 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() @@ -852,6 +972,7 @@ def main(): dl=test_dl, params=params, model=model, + rnn_lm_model=rnn_lm_model, HLG=HLG, H=H, bpe_model=bpe_model,