diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 45ad79313..2a956750f 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -429,7 +429,6 @@ After downloading, you will have the following files:
     |-- README.md
     |-- data
     |   |-- lang_bpe
-    |   |   |-- Linv.pt
     |   |   |-- HLG.pt
     |   |   |-- bpe.model
     |   |   |-- tokens.txt
@@ -447,10 +446,6 @@ After downloading, you will have the following files:
     6 directories, 11 files

 **File descriptions**:
-  - ``data/lang_bpe/Linv.pt``
-
-      It is the lexicon file, with word IDs as labels and token IDs as aux_labels.
-
   - ``data/lang_bpe/HLG.pt``

       It is the decoding graph.
@@ -551,7 +546,7 @@ The command to run CTC decoding is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --bpe-model ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/bpe.model \
     --method ctc-decoding \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
@@ -595,7 +590,8 @@ The command to run HLG decoding is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
@@ -637,7 +633,8 @@ The command to run HLG decoding + LM rescoring is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
     --method whole-lattice-rescoring \
     --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.8 \
@@ -684,7 +681,8 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
     --method attention-decoder \
     --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 1.3 \
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index be94e6875..edbdb5b2e 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -36,7 +36,6 @@ from icefall.decode import (
     rescore_with_attention_decoder,
     rescore_with_whole_lattice,
 )
-from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, get_texts


@@ -55,10 +54,27 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--words-file",
         type=str,
-        required=True,
-        help="Path to lang dir.",
+        help="""Path to words.txt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="""Path to HLG.pt.
+        Used only when method is not ctc-decoding.
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
     )

     parser.add_argument(
@@ -287,17 +303,16 @@ def main():

     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        lexicon = Lexicon(params.lang_dir)
-        max_token_id = max(lexicon.tokens)
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = bpe_model.get_piece_size() - 1
+
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )

-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(params.lang_dir + "/bpe.model")
-
         lattice = get_lattice(
             nnet_output=nnet_output,
             decoding_graph=H,
@@ -320,10 +335,8 @@
         "whole-lattice-rescoring",
         "attention-decoder",
     ]:
-        logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
-        HLG = k2.Fsa.from_dict(
-            torch.load(params.lang_dir + "/HLG.pt", map_location="cpu")
-        )
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -386,9 +399,7 @@ def main():
         best_path = next(iter(best_path_dict.values()))

         hyps = get_texts(best_path)
-        word_sym_table = k2.SymbolTable.from_file(
-            params.lang_dir + "/words.txt"
-        )
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
         hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
     else:
         raise ValueError(f"Unsupported decoding method: {params.method}")
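
A note for readers of this patch (commentary, not part of the diff): the reason ctc-decoding no longer needs a lang dir is that SentencePiece piece IDs are contiguous in [0, vocab_size), so the CTC topology can be sized from the BPE model alone, replacing the old max(lexicon.tokens). A minimal sketch of that path, assuming k2 and sentencepiece are installed; "bpe.model" is a placeholder for the --bpe-model value and the token IDs are made up:

    # Sketch only: how --bpe-model replaces the Lexicon for ctc-decoding.
    import k2
    import sentencepiece as spm

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("bpe.model")  # placeholder for the --bpe-model value

    # Piece IDs run from 0 to vocab_size - 1, so the largest token ID the
    # network can emit is get_piece_size() - 1; this stands in for the old
    # max(lexicon.tokens).
    max_token_id = bpe_model.get_piece_size() - 1

    # H accepts exactly the token IDs 0..max_token_id (0 is the CTC blank).
    H = k2.ctc_topo(max_token=max_token_id, modified=False)

    # After decoding, hypotheses are token IDs; the BPE model maps them back
    # to text with no words.txt involved. The IDs below are hypothetical.
    token_ids = [[23, 57, 104]]
    texts = bpe_model.decode(token_ids)  # one string per utterance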
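
Relatedly, the non-ctc-decoding methods now take the two files they actually read (--words-file and --HLG) instead of a whole lang dir. A sketch of that path under the same caveats; paths are placeholders and the word IDs hypothetical:

    # Sketch only: the HLG-based path after this patch.
    import k2
    import torch

    HLG = k2.Fsa.from_dict(torch.load("HLG.pt", map_location="cpu"))  # --HLG
    word_sym_table = k2.SymbolTable.from_file("words.txt")  # --words-file

    # get_texts(best_path) yields word IDs per utterance; words.txt maps
    # them back to words. The IDs below are hypothetical.
    hyps = [[201, 7502]]
    texts = [" ".join(word_sym_table[i] for i in ids) for ids in hyps]

One consequence of dropping required=True: none of the three new flags is enforced by argparse, so a flag missing for the chosen --method presumably only fails once the corresponding file is loaded.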