diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index a0332dd2a..3ed2d74e5 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -55,18 +55,11 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--words-file",
+        "--lang-dir",
         type=str,
         required=True,
-        help="Path to words.txt",
-    )
-
-    parser.add_argument(
-        "--HLG",
-        type=str,
-        required=True,
-        help="Path to HLG.pt.",
-    )
+        help="Path to lang bpe dir.",
+    )
 
     parser.add_argument(
         "--method",
@@ -167,13 +160,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        required=True,
-        help="Path to lang bpe dir.",
-    )
-
     parser.add_argument(
         "sound_files",
         type=str,
@@ -312,7 +298,7 @@ def main():
 
     logging.info("Loading BPE model")
     bpe_model = spm.SentencePieceProcessor()
-    bpe_model.load(str(params.lang_dir + "/bpe.model"))
+    bpe_model.load(params.lang_dir + "/bpe.model")
 
     lattice = get_lattice(
         nnet_output=nnet_output,
@@ -338,8 +324,8 @@ def main():
         "whole-lattice-rescoring",
         "attention-decoder",
     ]:
-        logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
+        HLG = k2.Fsa.from_dict(torch.load(params.lang_dir + "/HLG.pt", map_location="cpu"))
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
@@ -404,7 +390,7 @@ def main():
         best_path = next(iter(best_path_dict.values()))
 
     hyps = get_texts(best_path)
-    word_sym_table = k2.SymbolTable.from_file(params.words_file)
+    word_sym_table = k2.SymbolTable.from_file(params.lang_dir + "/words.txt")
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
 
     s = "\n"
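
For reference, a minimal sketch (not part of the diff) of the files pretrained.py now resolves from the single --lang-dir argument, which replaces the separate --words-file and --HLG options. The directory name data/lang_bpe_500 is an assumed example value based on the usual librispeech BPE recipe layout, not something the diff specifies.

    # Assumed example value for --lang-dir; substitute your own lang dir.
    lang_dir = "data/lang_bpe_500"

    bpe_model_path = lang_dir + "/bpe.model"  # read by spm.SentencePieceProcessor.load()
    hlg_path = lang_dir + "/HLG.pt"           # read only for the HLG-based decoding methods
    words_path = lang_dir + "/words.txt"      # read by k2.SymbolTable.from_file()

    print(bpe_model_path, hlg_path, words_path)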