From 33fa9e8b006b64986bcea43f303febbb6ca6cfa9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 29 Sep 2024 12:00:45 +0800 Subject: [PATCH] Add prefix beam search to aishell --- egs/aishell/ASR/zipformer/ctc_decode.py | 112 +++++++++++++++++++-- egs/gigaspeech/ASR/zipformer/ctc_decode.py | 6 +- 2 files changed, 108 insertions(+), 10 deletions(-) diff --git a/egs/aishell/ASR/zipformer/ctc_decode.py b/egs/aishell/ASR/zipformer/ctc_decode.py index 8073aa84b..01df090ab 100755 --- a/egs/aishell/ASR/zipformer/ctc_decode.py +++ b/egs/aishell/ASR/zipformer/ctc_decode.py @@ -123,6 +123,11 @@ from lhotse import set_caching_enabled from lhotse.cut import Cut from train import add_model_arguments, get_model, get_params +from icefall.context_graph import ContextGraph, ContextState +from icefall.ngram_lm import NgramLm, NgramLmStateCost +from icefall.lm_wrapper import LmScorer + + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -131,6 +136,9 @@ from icefall.checkpoint import ( ) from icefall.decode import ( ctc_greedy_search, + ctc_prefix_beam_search, + ctc_prefix_beam_search_attention_decoder_rescoring, + ctc_prefix_beam_search_shallow_fussion, get_lattice, one_best_decoding, rescore_with_attention_decoder_no_ngram, @@ -249,7 +257,24 @@ def get_parser(): "--skip-scoring", type=str2bool, default=False, - help="""Skip scoring, but still save the ASR output (for eval sets).""" + help="""Skip scoring, but still save the ASR output (for eval sets).""", + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, ) add_model_arguments(parser) @@ -262,8 +287,9 @@ def get_decoding_params() -> AttributeDict: params = AttributeDict( { "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, + "search_beam": 20, # for k2 fsa composition + "output_beam": 8, # for k2 fsa composition + "beam": 4, # for prefix-beam-search "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, @@ -278,6 +304,7 @@ def decode_one_batch( lexicon: Lexicon, batch: dict, H: Optional[k2.Fsa], + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -355,6 +382,43 @@ def decode_one_batch( key = "ctc-greedy-search" return {key: hyps} + if params.decoding_method == "prefix-beam-search": + hyp_tokens = ctc_prefix_beam_search( + ctc_output=ctc_output, encoder_out_lens=encoder_out_lens + ) + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + key = "prefix-beam-search" + return {key: hyps} + + if params.decoding_method == "ctc-prefix-beam-search-attention-decoder-rescoring": + best_path_dict = ctc_prefix_beam_search_attention_decoder_rescoring( + ctc_output=ctc_output, + attention_decoder=model.attention_decoder, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + ans = dict() + for a_scale_str, hyp_tokens in best_path_dict.items(): + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + ans[a_scale_str] = hyps + return ans + + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + hyp_tokens = ctc_prefix_beam_search_shallow_fussion( + ctc_output=ctc_output, + encoder_out_lens=encoder_out_lens, + LM=LM, + ) + hyps = [] + for i in range(batch_size): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + key = "prefix-beam-search-shallow-fussion" + return {key: hyps} + supervision_segments = torch.stack( ( supervisions["sequence_idx"], @@ -428,6 +492,7 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, H: Optional[k2.Fsa] = None, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -476,6 +541,7 @@ def decode_dataset( batch=batch, lexicon=lexicon, H=H, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -530,7 +596,9 @@ def save_wer_results( for key, results in results_dict.items(): # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( fd, @@ -545,7 +613,9 @@ def save_wer_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - wer_filename = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + wer_filename = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) with open(wer_filename, "w", encoding="utf8") as fd: print("settings\tWER", file=fd) @@ -564,6 +634,7 @@ def save_wer_results( def main(): parser = get_parser() AishellAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -578,15 +649,18 @@ def main(): assert params.decoding_method in ( "ctc-greedy-search", + "prefix-beam-search", + "ctc-prefix-beam-search-attention-decoder-rescoring", + "ctc-prefix-beam-search-shallow-fussion", "ctc-decoding", "attention-decoder-rescoring-no-ngram", ) params.res_dir = params.exp_dir / params.decoding_method if params.iter > 0: - params.suffix = f"iter-{params.iter}_avg-{params.avg}" + params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}_avg-{params.avg}" + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.causal: assert ( @@ -598,6 +672,11 @@ def main(): params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_left-context-{params.left_context_frames}" + if "prefix-beam-search" in params.decoding_method: + params.suffix += f"_beam-{params.beam}" + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + params.suffix += f"_lm-scale-{params.lm_scale}" + if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -621,7 +700,10 @@ def main(): params.eos_id = 1 params.sos_id = 1 - if params.decoding_method != "ctc-greedy-search": + if params.decoding_method in [ + "ctc-decoding", + "attention-decoder-rescoring-no-ngram", + ]: H = k2.ctc_topo( max_token=max_token_id, modified=True, @@ -630,6 +712,19 @@ def main(): else: H = None + # only load the neural network LM if required + if params.decoding_method == "ctc-prefix-beam-search-shallow-fussion": + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + logging.info("About to create model") model = get_model(params) @@ -746,6 +841,7 @@ def main(): model=model, H=H, lexicon=lexicon, + LM=LM, ) save_asr_output( diff --git a/egs/gigaspeech/ASR/zipformer/ctc_decode.py b/egs/gigaspeech/ASR/zipformer/ctc_decode.py index 2cb912c66..f9597379b 100755 --- a/egs/gigaspeech/ASR/zipformer/ctc_decode.py +++ b/egs/gigaspeech/ASR/zipformer/ctc_decode.py @@ -1064,11 +1064,13 @@ def main(): gigaspeech = GigaSpeechAsrDataModule(args) test_cuts = gigaspeech.test_cuts() + dev_cuts = gigaspeech.dev_cuts() test_dl = gigaspeech.test_dataloaders(test_cuts) + dev_dl = gigaspeech.test_dataloaders(dev_cuts) - test_sets = ["test"] - test_dls = [test_dl] + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] for test_set, test_dl in zip(test_sets, test_dls): results_dict = decode_dataset(