diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md index f2fd18cd4..ce9607991 100644 --- a/egs/librispeech/ASR/conformer_ctc/README.md +++ b/egs/librispeech/ASR/conformer_ctc/README.md @@ -1,5 +1,5 @@ -# How to use a pre-trained model to transcript a sound file +# How to use a pre-trained model to transcribe a sound file or multiple sound files You need to prepare 4 files: @@ -13,6 +13,14 @@ You need to prepare 4 files: Also, you need to install `kaldifeat`. Please refer to https://github.com/csukuangfj/kaldifeat for installation. +``` +./conformer_ctc/pretrained.py --help +``` + +displays the help information. + +## HLG decoding + Once you have the above files ready and have `kaldifeat` installed, you can run: @@ -20,7 +28,7 @@ you can run: ./conformer_ctc/pretrained.py \ --checkpoint /path/to/your/checkpoint.pt \ --words-file /path/to/words.txt \ - --hlg /path/to/HLG.pt \ + --HLG /path/to/HLG.pt \ /path/to/your/sound.wav ``` @@ -32,7 +40,60 @@ If you want to transcribe multiple files at the same time, you can use: ./conformer_ctc/pretrained.py \ --checkpoint /path/to/your/checkpoint.pt \ --words-file /path/to/words.txt \ - --hlg /path/to/HLG.pt \ + --HLG /path/to/HLG.pt \ + /path/to/your/sound1.wav \ + /path/to/your/sound2.wav \ + /path/to/your/sound3.wav +``` + +**Note**: This is the fastest decoding method. + +## HLG decoding + LM rescoring + +`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring` +and `attention decoder rescoring`. 
+ +To use whole lattice LM rescoring, you also need the following files: + + - G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh` + +The command to run decoding with LM rescoring is: + +``` +./conformer_ctc/pretrained.py \ + --checkpoint /path/to/your/checkpoint.pt \ + --words-file /path/to/words.txt \ + --HLG /path/to/HLG.pt \ + --method whole-lattice-rescoring \ + --G data/lm/G_4_gram.pt \ + --ngram-lm-scale 0.8 \ + /path/to/your/sound1.wav \ + /path/to/your/sound2.wav \ + /path/to/your/sound3.wav +``` + +## HLG decoding + LM rescoring + attention decoder rescoring + +To use the attention decoder for rescoring, you need the following extra information: + + - sos token ID + - eos token ID + +The command to run decoding with attention decoder rescoring is: + +``` +./conformer_ctc/pretrained.py \ + --checkpoint /path/to/your/checkpoint.pt \ + --words-file /path/to/words.txt \ + --HLG /path/to/HLG.pt \ + --method attention-decoder \ + --G data/lm/G_4_gram.pt \ + --ngram-lm-scale 1.3 \ + --attention-decoder-scale 1.2 \ + --lattice-score-scale 0.5 \ + --num-paths 100 \ + --sos-id 1 \ + --eos-id 1 \ /path/to/your/sound1.wav \ /path/to/your/sound2.wav \ /path/to/your/sound3.wav \ diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 27d9ccc4c..fbdeb39b5 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -12,7 +12,12 @@ import torchaudio from conformer import Conformer from torch.nn.utils.rnn import pad_sequence -from icefall.decode import get_lattice, one_best_decoding +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_whole_lattice, +) from icefall.utils import AttributeDict, get_texts @@ -25,8 +30,8 @@ def get_parser(): "--checkpoint", type=str, required=True, - help="Path to the checkpoint." 
- "The checkpoint is assume to be saved by " + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " "icefall.checkpoint.save_checkpoint().", ) @@ -38,7 +43,102 @@ def get_parser(): ) parser.add_argument( - "--hlg", type=str, required=True, help="Path to HLG.pt." + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from the rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring or attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) 
+ """, + ) + + parser.add_argument( + "--lattice-score-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, ) parser.add_argument( @@ -46,8 +146,8 @@ def get_parser(): type=str, nargs="+", help="The input sound file(s) to transcribe. " - "Supported formats are those that supported by torchaudio.load(). " - "For example, wav, flac are supported. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" "The sample rate has to be 16kHz.", ) @@ -108,6 +208,7 @@ def main(): params = get_params() params.update(vars(args)) + logging.info(f"{params}") device = torch.device("cpu") if torch.cuda.is_available(): @@ -115,7 +216,7 @@ def main(): logging.info(f"device: {device}") - logging.info("Create model") + logging.info("Creating model") model = Conformer( num_features=params.feature_dim, nhead=params.nhead, @@ -134,9 +235,24 @@ def main(): model.to(device) model.eval() - HLG = k2.Fsa.from_dict(torch.load(params.hlg)) + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + G.lm_scores = G.scores.clone() + + logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device opts.frame_opts.dither = 0 @@ -146,6 +262,7 @@ def main(): fbank = kaldifeat.Fbank(opts) + logging.info(f"Reading sound files: {params.sound_files}") waves = read_sound_files( filenames=params.sound_files, expected_sample_rate=params.sample_rate ) @@ -158,8 +275,9 @@ def main(): features, batch_first=True, padding_value=math.log(1e-10) ) + # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): - nnet_output, _, _ = model(features) + nnet_output, memory, memory_key_padding_mask = model(features) batch_size = nnet_output.shape[0] supervision_segments = torch.tensor( @@ -178,9 +296,37 @@ def main(): subsampling_factor=params.subsampling_factor, ) - best_path = one_best_decoding( - 
lattice=lattice, use_double_scores=params.use_double_scores - ) + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + scale=params.lattice_score_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) hyps = get_texts(best_path) word_sym_table = k2.SymbolTable.from_file(params.words_file) diff --git a/icefall/decode.py b/icefall/decode.py index bdcab23f3..8c1eef530 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -546,6 +546,8 @@ def rescore_with_whole_lattice( del lattice.lm_scores assert hasattr(lattice, "lm_scores") is False + assert hasattr(G_with_epsilon_loops, "lm_scores") + # Now, lattice.scores contains only am_scores # inv_lattice has word IDs as labels. 
@@ -677,10 +679,12 @@ def rescore_with_attention_decoder( num_paths: int, model: nn.Module, memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, + memory_key_padding_mask: Optional[torch.Tensor], sos_id: int, eos_id: int, scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, ) -> Dict[str, k2.Fsa]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest @@ -707,6 +711,10 @@ def rescore_with_attention_decoder( scale: It's the scale applied to the lattice.scores. A smaller value yields more unique paths. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. Returns: A dict of FsaVec, whose key contains a string ngram_lm_scale_attention_scale and the value is the @@ -794,11 +802,13 @@ def rescore_with_attention_decoder( path_to_seq_map_long = path_to_seq_map.to(torch.long) expanded_memory = memory.index_select(1, path_to_seq_map_long) - expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( - 0, path_to_seq_map_long - ) + if memory_key_padding_mask is not None: + expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + 0, path_to_seq_map_long + ) + else: + expanded_memory_key_padding_mask = None - # TODO: pass the sos_token_id and eos_token_id via function arguments nll = model.decoder_nll( memory=expanded_memory, memory_key_padding_mask=expanded_memory_key_padding_mask, @@ -813,11 +823,17 @@ def rescore_with_attention_decoder( assert attention_scores.ndim == 1 assert attention_scores.numel() == num_word_seqs - ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + if ngram_lm_scale is None: + ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + else: + 
ngram_lm_scale_list = [ngram_lm_scale] - attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] - attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + if attention_scale is None: + attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] + attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + else: + attention_scale_list = [attention_scale] path_2axes = k2.ragged.remove_axis(path, 0)