Update pretrained.py

This commit is contained in:
Mingshuang Luo 2021-10-13 10:23:27 +08:00
parent 524afc02ba
commit 33b88eee2b

View File

@ -55,18 +55,11 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--words-file", "--lang-dir",
type=str, type=str,
required=True, required=True,
help="Path to words.txt", help="Path to lang bpe dir.",
) )
parser.add_argument(
"--HLG",
type=str,
required=True,
help="Path to HLG.pt.",
)
parser.add_argument( parser.add_argument(
"--method", "--method",
@ -167,13 +160,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--lang-dir",
type=str,
required=True,
help="Path to lang bpe dir.",
)
parser.add_argument( parser.add_argument(
"sound_files", "sound_files",
type=str, type=str,
@ -312,7 +298,7 @@ def main():
logging.info("Loading BPE model") logging.info("Loading BPE model")
bpe_model = spm.SentencePieceProcessor() bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir + "/bpe.model")) bpe_model.load(params.lang_dir + "/bpe.model")
lattice = get_lattice( lattice = get_lattice(
nnet_output=nnet_output, nnet_output=nnet_output,
@ -338,8 +324,8 @@ def main():
"whole-lattice-rescoring", "whole-lattice-rescoring",
"attention-decoder", "attention-decoder",
]: ]:
logging.info(f"Loading HLG from {params.HLG}") logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) HLG = k2.Fsa.from_dict(torch.load(params.lang_dir + "/HLG.pt", map_location="cpu"))
HLG = HLG.to(device) HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"): if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder # For whole-lattice-rescoring and attention-decoder
@ -404,7 +390,7 @@ def main():
best_path = next(iter(best_path_dict.values())) best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path) hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file) word_sym_table = k2.SymbolTable.from_file(params.lang_dir + "/words.txt")
hyps = [[word_sym_table[i] for i in ids] for ids in hyps] hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n" s = "\n"