From 40db9985ec39d44221e73b0d24ee1dd5eaa3e6c9 Mon Sep 17 00:00:00 2001
From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com>
Date: Thu, 14 Oct 2021 18:27:15 +0800
Subject: [PATCH] Update ctc-decoding on pretrained.py and conformer_ctc.rst

---
 .../recipes/librispeech/conformer_ctc.rst     | 11 +++--
 .../ASR/conformer_ctc/pretrained.py           | 46 ++++++++++++-------
 2 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/docs/source/recipes/librispeech/conformer_ctc.rst b/docs/source/recipes/librispeech/conformer_ctc.rst
index 45ad79313..b76df9475 100644
--- a/docs/source/recipes/librispeech/conformer_ctc.rst
+++ b/docs/source/recipes/librispeech/conformer_ctc.rst
@@ -551,7 +551,7 @@ The command to run CTC decoding is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --bpe-model ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/bpe.model \
     --method ctc-decoding \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
@@ -595,7 +595,8 @@ The command to run HLG decoding is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1089-134686-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0001.flac \
     ./tmp/icefall_asr_librispeech_conformer_ctc/test_wavs/1221-135766-0002.flac
@@ -637,7 +638,8 @@ The command to run HLG decoding + LM rescoring is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
     --method whole-lattice-rescoring \
     --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 0.8 \
@@ -684,7 +686,8 @@ The command to run HLG decoding + LM rescoring + attention decoder rescoring is:
   $ cd egs/librispeech/ASR
   $ ./conformer_ctc/pretrained.py \
     --checkpoint ./tmp/icefall_asr_librispeech_conformer_ctc/exp/pretrained.pt \
-    --lang-dir ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe \
+    --words-file ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/words.txt \
+    --HLG ./tmp/icefall_asr_librispeech_conformer_ctc/data/lang_bpe/HLG.pt \
     --method attention-decoder \
     --G ./tmp/icefall_asr_librispeech_conformer_ctc/data/lm/G_4_gram.pt \
     --ngram-lm-scale 1.3 \
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index be94e6875..b4142ab89 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -20,6 +20,7 @@
 import argparse
 import logging
 import math
+import os
 from typing import List
 
 import k2
@@ -36,7 +37,6 @@ from icefall.decode import (
     rescore_with_attention_decoder,
     rescore_with_whole_lattice,
 )
-from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -55,10 +55,21 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
+        "--words-file",
         type=str,
-        required=True,
-        help="Path to lang dir.",
+        help="Path to words.txt",
+    )
+
+    parser.add_argument(
+        "--HLG",
+        type=str,
+        help="Path to HLG.pt.",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="Path to bpe.model.",
     )
 
     parser.add_argument(
@@ -287,17 +298,19 @@ def main():
 
     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        lexicon = Lexicon(params.lang_dir)
-        max_token_id = max(lexicon.tokens)
+        if not os.path.exists(params.bpe_model):
+            raise ValueError("The path to bpe.model is required!")
+
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
+        max_token_id = bpe_model.get_piece_size() - 1
+
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,
             device=device,
         )
 
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(params.lang_dir + "/bpe.model")
-
         lattice = get_lattice(
             nnet_output=nnet_output,
             decoding_graph=H,
@@ -320,10 +333,13 @@ def main():
         "whole-lattice-rescoring",
         "attention-decoder",
     ]:
-        logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
-        HLG = k2.Fsa.from_dict(
-            torch.load(params.lang_dir + "/HLG.pt", map_location="cpu")
-        )
+        if not os.path.exists(params.HLG):
+            raise ValueError("The path to HLG.pt is required!")
+        if not os.path.exists(params.words_file):
+            raise ValueError("The path to words.txt is required!")
+
+        logging.info(f"Loading HLG from {params.HLG}")
+        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -386,9 +402,7 @@ def main():
         best_path = next(iter(best_path_dict.values()))
 
         hyps = get_texts(best_path)
-        word_sym_table = k2.SymbolTable.from_file(
-            params.lang_dir + "/words.txt"
-        )
+        word_sym_table = k2.SymbolTable.from_file(params.words_file)
         hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
     else:
         raise ValueError(f"Unsupported decoding method: {params.method}")