From 9a01b9098deb56c9c4b048c000b4eead756c98f5 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Wed, 2 Nov 2022 18:03:56 +0800
Subject: [PATCH] include previously added decoding method

---
 .../ASR/lstm_transducer_stateless2/decode.py  | 65 +++++++++++++++----
 1 file changed, 51 insertions(+), 14 deletions(-)

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
index fc077f062..20a5ebd8b 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -131,11 +131,13 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
     modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model
 
+from icefall import NgramLm
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -232,6 +234,7 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
+          - modified_beam_search_ngram_rescoring
          - modified_beam_search_rnnlm_shallow_fusion  # for RNN LM shallow fusion
         If you use fast_beam_search_nbest_LG, you have to specify
         `--lang-dir`, which should contain `LG.pt`.
@@ -386,7 +389,23 @@ def get_parser():
         last output linear layer
         """,
     )
-    parser.add_argument("--ilm-scale", type=float, default=-0.1)
+
+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Order of the token-level n-gram LM used for rescoring.
+        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+        Used only when the decoding method is modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -399,6 +418,8 @@ def decode_one_batch(
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
@@ -534,6 +555,17 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search_ngram_rescoring":
+        hyp_tokens = modified_beam_search_ngram_rescoring(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion":
         hyp_tokens = modified_beam_search_rnnlm_shallow_fusion(
             model=model,
@@ -595,9 +627,11 @@ def decode_dataset(
     sp: spm.SentencePieceProcessor,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
     rnnlm: Optional[RnnLmModel] = None,
     rnnlm_scale: float = 1.0,
-) -> Dict[str, List[Tuple[List[str], List[str]]]]:
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
 
     Args:
@@ -638,13 +672,6 @@ def decode_dataset(
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
         cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
-        total_duration = sum(
-            [cut.duration for cut in batch["supervisions"]["cut"]]
-        )
-
-        logging.info(
-            f"Decoding {batch_idx}-th batch, batch size is {len(cut_ids)}, total duration is {total_duration}"
-        )
 
         hyps_dict = decode_one_batch(
             params=params,
@@ -651,8 +678,10 @@ def decode_dataset(
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
             word_table=word_table,
             batch=batch,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnnlm=rnnlm,
             rnnlm_scale=rnnlm_scale,
         )
@@ -680,7 +709,7 @@
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
@@ -740,6 +769,7 @@ def main():
         "fast_beam_search_nbest_LG",
         "fast_beam_search_nbest_oracle",
         "modified_beam_search",
+        "modified_beam_search_ngram_rescoring",
         "modified_beam_search_rnnlm_shallow_fusion",
     )
     params.res_dir = params.exp_dir / params.decoding_method
@@ -765,13 +795,10 @@ def main():
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
-
+    params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
     if "rnnlm" in params.decoding_method:
         params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}"
 
-    if "ILME" in params.decoding_method:
-        params.suffix += f"-ILME-scale={params.ilm_scale}"
-
     if params.use_averaged_model:
         params.suffix += "-use-averaged-model"
 
@@ -884,6 +911,14 @@ def main():
     model.to(device)
     model.eval()
 
+    lm_filename = f"{params.tokens_ngram}gram.fst.txt"
+    logging.info(f"lm filename: {lm_filename}")
+    ngram_lm = NgramLm(
+        str(params.lang_dir / lm_filename),
+        backoff_id=params.backoff_id,
+        is_binary=False,
+    )
+    logging.info(f"num states: {ngram_lm.lm.num_states}")
     # only load rnnlm if used
     if "rnnlm" in params.decoding_method:
         rnn_lm_scale = params.rnn_lm_scale
@@ -951,6 +986,8 @@ def main():
             sp=sp,
             word_table=word_table,
             decoding_graph=decoding_graph,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=params.ngram_lm_scale,
            rnnlm=rnn_lm_model,
            rnnlm_scale=rnn_lm_scale,
        )
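
For context, here is a minimal sketch of how the pieces added by this patch fit together at decode time, mirroring what `decode_one_batch` does for the new method. It is not part of the patch: the helper name `rescore_one_batch`, the LM path, the backoff ID, the scale, and the beam are illustrative assumptions; in `decode.py` they come from `--tokens-ngram`, `--backoff-id`, `--ngram-lm-scale`, and `--beam-size`.

```python
# Sketch only; values below are illustrative, not taken from the patch.
from typing import List

import torch

from beam_search import modified_beam_search_ngram_rescoring
from icefall import NgramLm


def rescore_one_batch(  # hypothetical helper, not in decode.py
    model: torch.nn.Module,
    encoder_out: torch.Tensor,  # (N, T, C), output of model.encoder
    encoder_out_lens: torch.Tensor,  # (N,), valid frames per utterance
) -> List[List[int]]:
    # Token-level 3-gram LM in OpenFst text format; decode.py builds the
    # filename as f"{params.tokens_ngram}gram.fst.txt" under params.lang_dir.
    ngram_lm = NgramLm(
        "data/lang_bpe_500/3gram.fst.txt",  # assumed location
        backoff_id=500,  # must match the backoff symbol ID in the FST
        is_binary=False,
    )
    # Modified beam search with on-the-fly n-gram rescoring, as in
    # decode_one_batch() for modified_beam_search_ngram_rescoring.
    return modified_beam_search_ngram_rescoring(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        ngram_lm=ngram_lm,
        ngram_lm_scale=0.1,  # illustrative; tune on a dev set
        beam=4,
    )
```

As in the patch, the returned token IDs would then be converted to words with `sp.decode(hyp_tokens)` before scoring.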