Update pretrained.py

This commit is contained in:
Mingshuang Luo 2021-10-13 10:23:27 +08:00
parent 524afc02ba
commit 33b88eee2b

View File

@ -55,18 +55,11 @@ def get_parser():
)
parser.add_argument(
"--words-file",
"--lang-dir",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG",
type=str,
required=True,
help="Path to HLG.pt.",
)
help="Path to lang bpe dir.",
)
parser.add_argument(
"--method",
@ -167,13 +160,6 @@ def get_parser():
""",
)
parser.add_argument(
"--lang-dir",
type=str,
required=True,
help="Path to lang bpe dir.",
)
parser.add_argument(
"sound_files",
type=str,
@ -312,7 +298,7 @@ def main():
logging.info("Loading BPE model")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir + "/bpe.model"))
bpe_model.load(params.lang_dir + "/bpe.model")
lattice = get_lattice(
nnet_output=nnet_output,
@ -338,8 +324,8 @@ def main():
"whole-lattice-rescoring",
"attention-decoder",
]:
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
HLG = k2.Fsa.from_dict(torch.load(params.lang_dir + "/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
@ -404,7 +390,7 @@ def main():
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
word_sym_table = k2.SymbolTable.from_file(params.lang_dir + "/words.txt")
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"