diff --git a/egs/librispeech/ASR/zipformer/ctc_decode.py b/egs/librispeech/ASR/zipformer/ctc_decode.py
index 914cb9ffd..85ceb61b8 100755
--- a/egs/librispeech/ASR/zipformer/ctc_decode.py
+++ b/egs/librispeech/ASR/zipformer/ctc_decode.py
@@ -73,6 +73,29 @@ Usage:
     --nbest-scale 1.0 \
     --lm-dir data/lm \
     --decoding-method whole-lattice-rescoring
+
+(6) attention-decoder-rescoring-no-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --decoding-method attention-decoder-rescoring-no-ngram
+
+(7) attention-decoder-rescoring-with-ngram
+./zipformer/ctc_decode.py \
+    --epoch 30 \
+    --avg 15 \
+    --exp-dir ./zipformer/exp \
+    --use-ctc 1 \
+    --use-attention-decoder 1 \
+    --max-duration 100 \
+    --hlg-scale 0.6 \
+    --nbest-scale 1.0 \
+    --lm-dir data/lm \
+    --decoding-method attention-decoder-rescoring-with-ngram
 """


@@ -101,10 +124,10 @@ from icefall.decode import (
     nbest_decoding,
     nbest_oracle,
     one_best_decoding,
-    rescore_with_n_best_list,
-    rescore_with_whole_lattice,
     rescore_with_attention_decoder_no_ngram,
     rescore_with_attention_decoder_with_ngram,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -214,6 +237,10 @@ def get_parser():
           - (6) nbest-oracle. Its WER is the lower bound of any n-best
             rescoring method can achieve. Useful for debugging n-best
             rescoring method.
+          - (7) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+            lattice, rescore them with the attention decoder.
+          - (8) attention-decoder-rescoring-with-ngram. Extract n paths from the LM
+            rescored lattice, rescore them with the attention decoder.
         """,
     )

diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py
index 2b8d1aaf3..1f3373cd8 100755
--- a/egs/librispeech/ASR/zipformer/export.py
+++ b/egs/librispeech/ASR/zipformer/export.py
@@ -404,6 +404,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     params.vocab_size = num_tokens(token_table) + 1

     logging.info(params)

@@ -466,8 +467,6 @@ def main():
                     device=device,
                 )
             )
-        elif params.avg == 1:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
         else:
             assert params.avg > 0, params.avg
             start = params.epoch - params.avg
diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
index 408d13576..4341ef61f 100755
--- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py
+++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py
@@ -81,6 +81,15 @@ Usage of this script:
   --sample-rate 16000 \
   /path/to/foo.wav \
   /path/to/bar.wav
+
+(5) attention-decoder-rescoring-no-ngram
+./zipformer/pretrained_ctc.py \
+  --checkpoint ./zipformer/exp/pretrained.pt \
+  --tokens data/lang_bpe_500/tokens.txt \
+  --method attention-decoder-rescoring-no-ngram \
+  --sample-rate 16000 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
 """

 import argparse
@@ -100,6 +109,7 @@ from train import add_model_arguments, get_model, get_params
 from icefall.decode import (
     get_lattice,
     one_best_decoding,
+    rescore_with_attention_decoder_no_ngram,
     rescore_with_n_best_list,
     rescore_with_whole_lattice,
 )
@@ -172,6 +182,8 @@ def get_parser():
         decoding lattice and then use 1best to decode the rescored lattice.
         We call it HLG decoding + whole-lattice n-gram LM rescoring.
+        (4) attention-decoder-rescoring-no-ngram. Extract n paths from the decoding
+        lattice, rescore them with the attention decoder.
         """,
     )

@@ -276,6 +288,7 @@ def main():
     token_table = k2.SymbolTable.from_file(params.tokens)
     params.vocab_size = num_tokens(token_table) + 1  # +1 for blank
     params.blank_id = token_table["<blk>"]
+    params.sos_id = params.eos_id = token_table["<sos/eos>"]
     assert params.blank_id == 0

     logging.info(f"{params}")
@@ -333,16 +346,13 @@ def main():
         dtype=torch.int32,
     )

-    if params.method == "ctc-decoding":
-        logging.info("Use CTC decoding")
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         max_token_id = params.vocab_size - 1
-
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )
-
         lattice = get_lattice(
             nnet_output=ctc_output,
             decoding_graph=H,
@@ -354,9 +364,23 @@ def main():
             subsampling_factor=params.subsampling_factor,
         )

-        best_path = one_best_decoding(
-            lattice=lattice, use_double_scores=params.use_double_scores
-        )
+        if params.method == "ctc-decoding":
+            logging.info("Use CTC decoding")
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+        else:
+            logging.info("Use attention decoder rescoring without ngram")
+            best_path_dict = rescore_with_attention_decoder_no_ngram(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                attention_decoder=model.attention_decoder,
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                nbest_scale=params.nbest_scale,
+            )
+            best_path = next(iter(best_path_dict.values()))
+
         token_ids = get_texts(best_path)
         hyps = [[token_table[i] for i in ids] for ids in token_ids]
     elif params.method in [
@@ -430,7 +454,7 @@ def main():
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
-    if params.method == "ctc-decoding":
+    if params.method in ["ctc-decoding", "attention-decoder-rescoring-no-ngram"]:
         for filename, hyp in zip(params.sound_files, hyps):
             words = "".join(hyp)
             words = words.replace("▁", " ").strip()
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 213675170..cce058d6c 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -1199,8 +1199,7 @@ def run(rank, world_size, args):

     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.eos_id = sp.piece_to_id("<sos/eos>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = params.eos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer: