diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode.py b/egs/librispeech/WSASR/conformer_ctc2/decode.py
index 45c630f59..b883e448a 100755
--- a/egs/librispeech/WSASR/conformer_ctc2/decode.py
+++ b/egs/librispeech/WSASR/conformer_ctc2/decode.py
@@ -2,6 +2,7 @@
 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
 #                                            Fangjun Kuang,
 #                                            Quandong Wang)
+#           2023 Johns Hopkins University (Author: Dongji Gao)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -40,17 +41,10 @@ from icefall.checkpoint import (
 )
 from icefall.decode import (
     get_lattice,
-    nbest_decoding,
-    nbest_oracle,
     one_best_decoding,
-    rescore_with_attention_decoder,
-    rescore_with_n_best_list,
-    rescore_with_rnn_lm,
-    rescore_with_whole_lattice,
 )
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.rnn_lm.model import RnnLmModel
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -119,22 +113,7 @@ def get_parser():
           model for decoding. It produces the same results with ctc-decoding.
         - (2) 1best. Extract the best path from the decoding lattice as the
           decoding result.
-        - (3) nbest. Extract n paths from the decoding lattice; the path
-          with the highest score is the decoding result.
-        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
-          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
-          the highest score is the decoding result.
-        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
-          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
-          is the decoding result.
-        - (6) attention-decoder. Extract n paths from the LM rescored
-          lattice, the path with the highest score is the decoding result.
-        - (7) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume
-          you have trained an RNN LM using ./rnn_lm/train.py
-        - (8) nbest-oracle. Its WER is the lower bound of any n-best
-          rescoring method can achieve. Useful for debugging n-best
-          rescoring method.
-        """,
+        """,
     )
 
     parser.add_argument(
@@ -157,28 +136,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""Number of paths for n-best based decoding method.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
-        """,
-    )
-
-    parser.add_argument(
-        "--nbest-scale",
-        type=float,
-        default=0.5,
-        help="""The scale to be applied to `lattice.scores`.
-        It's needed if you use any kinds of n-best based rescoring.
-        Used only when "method" is one of the following values:
-        nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle
-        A smaller value results in more unique paths.
-        """,
-    )
-
     parser.add_argument(
         "--exp-dir", type=str, default="conformer_ctc2/exp", help="The experiment dir",
     )
@@ -196,59 +153,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--rnn-lm-exp-dir",
-        type=str,
-        default="rnn_lm/exp",
-        help="""Used only when --method is rnn-lm.
-        It specifies the path to RNN LM exp dir.
-        """,
-    )
-
-    parser.add_argument(
-        "--rnn-lm-epoch",
-        type=int,
-        default=7,
-        help="""Used only when --method is rnn-lm.
-        It specifies the checkpoint to use.
-        """,
-    )
-
-    parser.add_argument(
-        "--rnn-lm-avg",
-        type=int,
-        default=2,
-        help="""Used only when --method is rnn-lm.
-        It specifies the number of checkpoints to average.
-        """,
-    )
-
-    parser.add_argument(
-        "--rnn-lm-embedding-dim",
-        type=int,
-        default=2048,
-        help="Embedding dim of the model",
-    )
-
-    parser.add_argument(
-        "--rnn-lm-hidden-dim", type=int, default=2048, help="Hidden dim of the model",
-    )
-
-    parser.add_argument(
-        "--rnn-lm-num-layers",
-        type=int,
-        default=4,
-        help="Number of RNN layers the model",
-    )
-    parser.add_argument(
-        "--rnn-lm-tie-weights",
-        type=str2bool,
-        default=False,
-        help="""True to share the weights between the input embedding layer and the
-        last output linear layer
-        """,
-    )
-
     return parser
 
 
@@ -256,8 +160,8 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             # parameters for conformer
-            "subsampling_factor": 2,
-            "feature_dim": 768,
+            "subsampling_factor": 4,
+            "feature_dim": 80,
             "nhead": 8,
             "dim_feedforward": 2048,
             "encoder_dim": 512,
@@ -319,7 +223,6 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
@@ -345,15 +248,9 @@ def decode_one_batch(
         It's the return value of :func:`get_params`.
 
         - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
-        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
-          rescoring.
 
       model:
         The neural model.
-      rnn_lm_model:
-        The neural model for RNN LM.
       HLG:
         The decoding graph. Used only when params.method is NOT ctc-decoding.
       H:
@@ -458,121 +355,23 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}
 
-    if params.method == "nbest-oracle":
-        # Note: You can also pass rescored lattices to it.
-        # We choose the HLG decoded lattice for speed reasons
-        # as HLG decoding is faster and the oracle WER
-        # is only slightly worse than that of rescored lattices.
-        best_path = nbest_oracle(
-            lattice=lattice,
-            num_paths=params.num_paths,
-            ref_texts=supervisions["text"],
-            word_table=word_table,
-            nbest_scale=params.nbest_scale,
-            oov="<UNK>",
+    if params.method in ["1best"]:
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
         )
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
-        return {key: hyps}
-
-    if params.method in ["1best", "nbest"]:
-        if params.method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-                nbest_scale=params.nbest_scale,
-            )
-            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+        key = "no_rescore"
 
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
+
         return {key: hyps}
-
-    assert params.method in [
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-        "rnn-lm",
-    ]
-
-    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
-    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.method == "whole-lattice-rescoring":
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list,
-        )
-    elif params.method == "attention-decoder":
-        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None,
-        )
-        # TODO: pass `lattice` instead of `rescored_lattice` to
-        # `rescore_with_attention_decoder`
-
-        best_path_dict = rescore_with_attention_decoder(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            nbest_scale=params.nbest_scale,
-        )
-    elif params.method == "rnn-lm":
-        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
-        rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None,
-        )
-
-        best_path_dict = rescore_with_rnn_lm(
-            lattice=rescored_lattice,
-            num_paths=params.num_paths,
-            rnn_lm_model=rnn_lm_model,
-            model=model,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
-            blank_id=0,
-            nbest_scale=params.nbest_scale,
-        )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
 
-    ans = dict()
-    if best_path_dict is not None:
-        for lm_scale_str, best_path in best_path_dict.items():
-            hyps = get_texts(best_path)
-            hyps = [[word_table[i] for i in ids] for ids in hyps]
-            ans[lm_scale_str] = hyps
-    else:
-        ans = None
-    return ans
-
-
 def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    rnn_lm_model: Optional[nn.Module],
     HLG: Optional[k2.Fsa],
     H: Optional[k2.Fsa],
     bpe_model: Optional[spm.SentencePieceProcessor],
@@ -590,8 +389,6 @@ def decode_dataset(
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      rnn_lm_model:
-        The neural model for RNN LM.
      HLG:
        The decoding graph. Used only when params.method is NOT ctc-decoding.
      H:
@@ -630,7 +427,6 @@ def decode_dataset(
         hyps_dict = decode_one_batch(
             params=params,
             model=model,
-            rnn_lm_model=rnn_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,
@@ -774,58 +570,7 @@ def main():
         if not hasattr(HLG, "lm_scores"):
             HLG.lm_scores = HLG.scores.clone()
 
-    if params.method in (
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-        "attention-decoder",
-        "rnn-lm",
-    ):
-        if not (params.lm_dir / "G_4_gram.pt").is_file():
-            logging.info("Loading G_4_gram.fst.txt")
-            logging.warning("It may take 8 minutes.")
-            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-                first_word_disambig_id = lexicon.word_table["#0"]
-
-                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-                # G.aux_labels is not needed in later computations, so
-                # remove it here.
-                del G.aux_labels
-                # CAUTION: The following line is crucial.
-                # Arcs entering the back-off state have label equal to #0.
-                # We have to change it to 0 here.
-                G.labels[G.labels >= first_word_disambig_id] = 0
-                # See https://github.com/k2-fsa/k2/issues/874
-                # for why we need to set G.properties to None
-                G.__dict__["_properties"] = None
-                G = k2.Fsa.from_fsas([G]).to(device)
-                G = k2.arc_sort(G)
-                # Save a dummy value so that it can be loaded in C++.
-                # See https://github.com/pytorch/pytorch/issues/67902
-                # for why we need to do this.
-                G.dummy = 1
-
-                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
-        else:
-            logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
-            G = k2.Fsa.from_dict(d)
-
-        if params.method in [
-            "whole-lattice-rescoring",
-            "attention-decoder",
-            "rnn-lm",
-        ]:
-            # Add epsilon self-loops to G as we will compose
-            # it with the whole lattice later
-            G = k2.add_epsilon_self_loops(G)
-            G = k2.arc_sort(G)
-            G = G.to(device)
-
-            # G.lm_scores is used to replace HLG.lm_scores during
-            # LM rescoring.
-            G.lm_scores = G.scores.clone()
-    else:
-        G = None
+    G = None
 
     model = Conformer(
         num_features=params.feature_dim,
@@ -919,30 +664,6 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
-    rnn_lm_model = None
-    if params.method == "rnn-lm":
-        rnn_lm_model = RnnLmModel(
-            vocab_size=params.num_classes,
-            embedding_dim=params.rnn_lm_embedding_dim,
-            hidden_dim=params.rnn_lm_hidden_dim,
-            num_layers=params.rnn_lm_num_layers,
-            tie_weights=params.rnn_lm_tie_weights,
-        )
-        if params.rnn_lm_avg == 1:
-            load_checkpoint(
-                f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", rnn_lm_model,
-            )
-            rnn_lm_model.to(device)
-        else:
-            rnn_lm_model = load_averaged_model(
-                params.rnn_lm_exp_dir,
-                rnn_lm_model,
-                params.rnn_lm_epoch,
-                params.rnn_lm_avg,
-                device,
-            )
-        rnn_lm_model.eval()
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     librispeech = LibriSpeechAsrDataModule(args)
@@ -961,7 +682,6 @@
             dl=test_dl,
             params=params,
            model=model,
-            rnn_lm_model=rnn_lm_model,
             HLG=HLG,
             H=H,
             bpe_model=bpe_model,