diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst index 73c5503d8..ad6f164ea 100644 --- a/docs/source/recipes/librispeech/conformer_ctc.rst +++ b/docs/source/recipes/librispeech/conformer_ctc.rst @@ -528,10 +528,56 @@ displays the help information. It supports three decoding methods: + - CTC decoding - HLG decoding - HLG + n-gram LM rescoring - HLG + n-gram LM rescoring + attention decoder rescoring +CTC decoding +^^^^^^^^^^^^ + +CTC decoding uses the best path of the decoding lattice as the decoding result +without any LM or lexicon. + +The command to run CTC decoding is: + +.. code-block:: bash + + $ cd egs/librispeech/ASR + $ ./conformer_ctc/pretrained.py \ + --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \ + --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \ + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac + +The output is given below: + +.. code-block:: + + 2021-10-09 21:06:57,154 INFO [pretrained.py:253] device: cuda:0 + 2021-10-09 21:06:57,154 INFO [pretrained.py:255] Creating model + 2021-10-09 21:07:04,234 INFO [pretrained.py:272] Constructing Fbank computer + 2021-10-09 21:07:04,235 INFO [pretrained.py:282] Reading sound files: ['./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac'] + 2021-10-09 21:07:04,248 INFO [pretrained.py:288] Decoding started + 2021-10-09 21:07:05,041 INFO [pretrained.py:306] Building CTC topology + 2021-10-09 21:07:05,334 INFO [lexicon.py:113] Loading pre-compiled tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/Linv.pt + 2021-10-09 21:07:05,380 INFO [pretrained.py:315] Loading BPE model + 2021-10-09 21:07:07,905 INFO [pretrained.py:330] Use CTC decoding + 2021-10-09 21:07:07,918 INFO [pretrained.py:407] + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac: + AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS + + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac: + GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED + BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN + + ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac: + YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION + + + 2021-10-09 21:07:07,918 INFO [pretrained.py:409] Decoding Done + HLG decoding ^^^^^^^^^^^^ diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 00812d674..31bd9bdf3 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -19,6 +20,7 @@ import argparse import logging import math +import sentencepiece as spm from typing import List import k2 @@ -28,6 +30,7 @@ import torchaudio from conformer import Conformer from torch.nn.utils.rnn import pad_sequence +from icefall.lexicon import Lexicon from icefall.decode import ( get_lattice, one_best_decoding, @@ -54,12 +57,17 @@ def get_parser(): parser.add_argument( "--words-file", type=str, - required=True, + default="./tmp/icefall_asr_librispeech_conformer_ctc/ \ + data/lang_bpe/words.txt", help="Path to words.txt", ) parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." + "--HLG", + type=str, + default="./tmp/icefall_asr_librispeech_conformer_ctc/ \ + data/lang_bpe/HLG.pt", + help="Path to HLG.pt.", ) parser.add_argument( @@ -68,6 +76,10 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. @@ -157,6 +169,14 @@ def get_parser(): """, ) + parser.add_argument( + "--lang-dir", + type=str, + default="./tmp/icefall_asr_librispeech_conformer_ctc/ \ + data/lang_bpe", + help="Path to lang bpe dir.", + ) + parser.add_argument( "sound_files", type=str, @@ -249,23 +269,6 @@ def main(): model.to(device) model.eval() - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device @@ -299,52 +302,103 @@ def main(): dtype=torch.int32, ) - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) + if params.method == "ctc-decoding": + logging.info("Building CTC topology") + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) - if params.method == "1best": - logging.info("Use HLG decoding") + logging.info("Loading BPE model") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir + "/bpe.model")) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + logging.info("Use CTC decoding") best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info("Use HLG + LM rescoring + attention decoder rescoring") - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] s = "\n" for filename, hyp in zip(params.sound_files, hyps): @@ -361,4 +415,4 @@ if __name__ == "__main__": ) logging.basicConfig(format=formatter, level=logging.INFO) - main() + main() \ No newline at end of file