mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Update pretrained.py
This commit is contained in:
parent
524afc02ba
commit
33b88eee2b
@ -55,18 +55,11 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to words.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to HLG.pt.",
|
||||
)
|
||||
help="Path to lang bpe dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
@ -167,13 +160,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to lang bpe dir.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
@ -312,7 +298,7 @@ def main():
|
||||
|
||||
logging.info("Loading BPE model")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir + "/bpe.model"))
|
||||
bpe_model.load(params.lang_dir + "/bpe.model")
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
@ -338,8 +324,8 @@ def main():
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.lang_dir + "/HLG.pt", map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
@ -404,7 +390,7 @@ def main():
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.lang_dir + "/words.txt")
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
|
||||
s = "\n"
|
||||
|
Loading…
x
Reference in New Issue
Block a user