From 154ef4cfa5d369606825d2141d1ca9cda5923ba1 Mon Sep 17 00:00:00 2001
From: pkufool
Date: Tue, 29 Oct 2024 15:36:30 +0800
Subject: [PATCH] Support prefix beam search / shallow fusion / hotwords in LibriSpeech CTC decode

---
 egs/librispeech/ASR/zipformer/ctc_decode.py | 238 ++++++++++++++++++--
 1 file changed, 219 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 9db429959..156989b78 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -111,6 +111,7 @@ Usage:
 import argparse
 import logging
 import math
+import os
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
@@ -129,8 +130,14 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
+
+from icefall.context_graph import ContextGraph, ContextState
+
 from icefall.decode import (
     ctc_greedy_search,
+    ctc_prefix_beam_search,
+    ctc_prefix_beam_search_attention_decoder_rescoring,
+    ctc_prefix_beam_search_shallow_fussion,
     get_lattice,
     nbest_decoding,
     nbest_oracle,
@@ -140,7 +147,11 @@ from icefall.decode import (
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
+
+from icefall.ngram_lm import NgramLm, NgramLmStateCost
 from icefall.lexicon import Lexicon
+from icefall.lm_wrapper import LmScorer
+
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -255,6 +266,12 @@ def get_parser():
         lattice, rescore them with the attention decoder.
         - (9) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
         rescored lattice, rescore them with the attention decoder.
+        - (10) ctc-prefix-beam-search. Extract n paths with the given beam; the
+        best path among them is the decoding result.
+        - (11) ctc-prefix-beam-search-attention-decoder-rescoring. Extract n paths
+        with the given beam, rescore them with the attention decoder.
+        - (12) ctc-prefix-beam-search-shallow-fussion. Use NNLM shallow fusion
+        during beam search; LODR and hotwords are also supported by this method.
         """,
     )
 
@@ -280,6 +297,23 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--nnlm-type",
+        type=str,
+        default="rnn",
+        help="Type of the neural network LM",
+        choices=["rnn", "transformer"],
+    )
+
+    parser.add_argument(
+        "--nnlm-scale",
+        type=float,
+        default=0,
+        help="""The scale of the neural network LM; 0 means NNLM shallow fusion
+        is disabled. Used only when --decoding-method is
+        ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--hlg-scale",
         type=float,
@@ -297,11 +331,52 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="ID of the backoff symbol in the n-gram LM",
+    )
+
+    parser.add_argument(
+        "--lodr-ngram",
+        type=str,
+        help="The path to the token-level n-gram LM used for LODR",
+    )
+
+    parser.add_argument(
+        "--lodr-lm-scale",
+        type=float,
+        default=0,
+        help="""The scale of the LODR n-gram; it should be negative.
+        0 means LODR is disabled.""",
+    )
+
+    parser.add_argument(
+        "--context-score",
+        type=float,
+        default=0,
+        help="""
+        The bonus score for each token of the context biasing words/phrases.
+        0 means contextual biasing is disabled.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-file",
+        type=str,
+        default="",
+        help="""
+        The path to the context biasing list; one word/phrase per line.
+        Used only when --decoding-method is ctc-prefix-beam-search-shallow-fussion.
+        """,
+    )
+
     parser.add_argument(
         "--skip-scoring",
         type=str2bool,
         default=False,
-        help="""Skip scoring, but still save the ASR output (for eval sets)."""
+        help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )
 
     add_model_arguments(parser)
@@ -314,11 +389,12 @@ def get_decoding_params() -> AttributeDict:
     params = AttributeDict(
         {
             "frame_shift_ms": 10,
-            "search_beam": 20,
-            "output_beam": 8,
+            "search_beam": 20,  # for k2 fsa composition
+            "output_beam": 8,  # for k2 fsa composition
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
+            "beam": 4,  # for prefix-beam-search
         }
     )
     return params
@@ -333,6 +409,9 @@ def decode_one_batch(
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -377,10 +456,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = params.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -411,6 +487,51 @@ def decode_one_batch(
         key = "ctc-greedy-search"
         return {key: hyps}
 
+    if params.decoding_method == "ctc-prefix-beam-search":
+        token_ids = ctc_prefix_beam_search(
+            ctc_output=ctc_output, encoder_out_lens=encoder_out_lens
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search"
+        return {key: hyps}
+
+    if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring":
+        best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring(
+            ctc_output=ctc_output,
+            attention_decoder=model.attention_decoder,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        ans = dict()
+        for a_scale_str, token_ids in best_path_dict.items():
+            # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+            hyps = bpe_model.decode(token_ids)
+            # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+            hyps = [s.split() for s in hyps]
+            ans[a_scale_str] = hyps
+        return ans
+
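+    # A sketch of what the branch below does: at each step of the prefix beam
+    # search the CTC scores are combined with NNLM scores scaled by
+    # --nnlm-scale; if a LODR n-gram is given, its score enters with the
+    # (negative) --lodr-lm-scale, and the context graph adds a bonus of
+    # --context-score per matched hotword token.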
+    if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+        token_ids = ctc_prefix_beam_search_shallow_fussion(
+            ctc_output=ctc_output,
+            encoder_out_lens=encoder_out_lens,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            LODR_lm_scale=params.lodr_lm_scale,
+            context_graph=context_graph,
+        )
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "prefix-beam-search-shallow-fussion"
+        return {key: hyps}
+
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
@@ -584,6 +705,9 @@ def decode_dataset(
     bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
+    NNLM: Optional[LmScorer] = None,
+    LODR_lm: Optional[NgramLm] = None,
+    context_graph: Optional[ContextGraph] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
     """Decode dataset.
@@ -634,6 +758,9 @@ def decode_dataset(
             batch=batch,
             word_table=word_table,
             G=G,
+            NNLM=NNLM,
+            LODR_lm=LODR_lm,
+            context_graph=context_graph,
         )
 
         for name, hyps in hyps_dict.items():
@@ -664,9 +791,7 @@ def save_asr_output(
     """
     for key, results in results_dict.items():
-        recogs_filename = (
-            params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
-        )
+        recogs_filename = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
 
         results = sorted(results)
         store_transcripts(filename=recogs_filename, texts=results)
@@ -680,7 +805,8 @@ def save_wer_results(
     results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
 ):
     if params.decoding_method in (
-        "attention-decoder-rescoring-with-ngram", "whole-lattice-rescoring"
+        "attention-decoder-rescoring-with-ngram",
+        "whole-lattice-rescoring",
     ):
         # Set it to False since there are too many logs.
         enable_log = False
@@ -721,6 +847,7 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
+    LmScorer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -735,8 +862,11 @@ def main():
     set_caching_enabled(True)  # lhotse
 
     assert params.decoding_method in (
-        "ctc-greedy-search", "ctc-decoding",
+        "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
         "1best",
         "nbest",
         "nbest-rescoring",
@@ -762,6 +892,16 @@ def main():
         params.suffix += f"_chunk-{params.chunk_size}"
         params.suffix += f"_left-context-{params.left_context_frames}"
 
+    if "prefix-beam-search" in params.decoding_method:
+        params.suffix += f"_beam-{params.beam}"
+        if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion":
+            if params.nnlm_scale != 0:
+                params.suffix += f"_nnlm-scale-{params.nnlm_scale}"
+            if params.lodr_lm_scale != 0:
+                params.suffix += f"_lodr-scale-{params.lodr_lm_scale}"
+            if params.context_score != 0:
+                params.suffix += f"_context-score-{params.context_score}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"
 
@@ -771,6 +911,7 @@ def main():
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", 0)
+    params.device = device
 
     logging.info(f"Device: {device}")
     logging.info(params)
@@ -786,14 +927,24 @@ def main():
     params.sos_id = 1
 
     if params.decoding_method in [
-        "ctc-greedy-search", "ctc-decoding", "attention-decoder-rescoring-no-ngram"
+        "ctc-decoding",
+        "ctc-greedy-search",
+        "ctc-prefix-beam-search",
+        "ctc-prefix-beam-search-attention-decoder-rescoring",
+        "ctc-prefix-beam-search-shallow-fussion",
+        "attention-decoder-rescoring-no-ngram",
     ]:
         HLG = None
-        H = k2.ctc_topo(
-            max_token=max_token_id,
-            modified=False,
-            device=device,
-        )
+        H = None
+        # Only the k2 lattice-based methods need the CTC topology H.
+        if params.decoding_method in [
+            "ctc-decoding",
+            "attention-decoder-rescoring-no-ngram",
+        ]:
+            H = k2.ctc_topo(
+                max_token=max_token_id,
+                modified=False,
+                device=device,
+            )
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(str(params.lang_dir / "bpe.model"))
     else:
@@ -844,7 +995,8 @@ def main():
             G = k2.Fsa.from_dict(d)
 
         if params.decoding_method in [
-            "whole-lattice-rescoring", "attention-decoder-rescoring-with-ngram"
+            "whole-lattice-rescoring",
+            "attention-decoder-rescoring-with-ngram",
         ]:
             # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
@@ -858,6 +1010,51 @@ def main():
     else:
         G = None
 
+    # Only load the neural network LM if required.
+    NNLM = None
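+    # NNLM, LODR_lm and context_graph stay None unless the decoding method is
+    # ctc-prefix-beam-search-shallow-fussion and the corresponding scale/score
+    # is non-zero, so each component can be toggled independently.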
"ctc-prefix-beam-search-shallow-fussion" + and params.nnlm_scale != 0 + ): + NNLM = LmScorer( + lm_type=params.nnlm_type, + params=params, + device=device, + lm_scale=params.nnlm_scale, + ) + NNLM.to(device) + NNLM.eval() + + LODR_lm = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.lodr_lm_scale != 0 + ): + assert os.path.exists( + params.lodr_ngram + ), f"LODR ngram does not exists, given path : {params.lodr_ngram}" + logging.info(f"Loading LODR (token level lm): {params.lodr_ngram}") + LODR_lm = NgramLm( + params.lodr_ngram, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {LODR_lm.lm.num_states}") + + context_graph = None + if ( + params.decoding_method == "ctc-prefix-beam-search-shallow-fussion" + and params.context_score != 0 + ): + assert os.path.exists( + params.context_file + ), f"context_file does not exists, given path : {params.context_file}" + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(bpe_model.encode(line.strip())) + context_graph = ContextGraph(params.context_score) + context_graph.build(contexts) + logging.info("About to create model") model = get_model(params) @@ -967,6 +1164,9 @@ def main(): bpe_model=bpe_model, word_table=lexicon.word_table, G=G, + NNLM=NNLM, + LODR_lm=LODR_lm, + context_graph=context_graph, ) save_asr_output(