Update pretrained.py

Repository: https://github.com/k2-fsa/icefall.git
commit 33b88eee2b
parent 524afc02ba
@@ -55,17 +55,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--words-file",
+        "--lang-dir",
         type=str,
         required=True,
-        help="Path to words.txt",
-    )
-
-    parser.add_argument(
-        "--HLG",
-        type=str,
-        required=True,
-        help="Path to HLG.pt.",
+        help="Path to lang bpe dir.",
     )

     parser.add_argument(
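With this hunk the script takes a single --lang-dir instead of separate --words-file and --HLG options; the later hunks resolve words.txt, bpe.model, and HLG.pt relative to that directory. A minimal sketch of the consolidated interface (the directory name data/lang_bpe_500 is a hypothetical example, not taken from this commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lang-dir",
    type=str,
    required=True,
    help="Path to lang bpe dir.",
)
args = parser.parse_args(["--lang-dir", "data/lang_bpe_500"])

# Files previously passed as separate options are now found by convention:
bpe_model_path = args.lang_dir + "/bpe.model"
hlg_path = args.lang_dir + "/HLG.pt"
words_path = args.lang_dir + "/words.txt"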
@@ -167,13 +160,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--lang-dir",
-        type=str,
-        required=True,
-        help="Path to lang bpe dir.",
-    )
-
     parser.add_argument(
         "sound_files",
         type=str,
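Dropping this second --lang-dir definition is required rather than cosmetic once the option is registered earlier in get_parser(): argparse rejects duplicate option strings. A minimal sketch of the failure the hunk avoids:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lang-dir", type=str, required=True)

# Registering the same option string twice raises:
#   argparse.ArgumentError: argument --lang-dir: conflicting option string: --lang-dir
parser.add_argument("--lang-dir", type=str, required=True)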
@@ -312,7 +298,7 @@ def main():

     logging.info("Loading BPE model")
     bpe_model = spm.SentencePieceProcessor()
-    bpe_model.load(str(params.lang_dir + "/bpe.model"))
+    bpe_model.load(params.lang_dir + "/bpe.model")

     lattice = get_lattice(
         nnet_output=nnet_output,
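The dropped str(...) wrapper was redundant: --lang-dir is declared with type=str, so params.lang_dir + "/bpe.model" is already a str (and if lang_dir were a pathlib.Path, the + concatenation itself would raise TypeError before str() ever ran). As an aside, a pathlib-based variant of the same load, shown only as a sketch of an alternative style rather than what the commit does:

from pathlib import Path

import sentencepiece as spm

lang_dir = Path("data/lang_bpe_500")  # hypothetical value of params.lang_dir

bpe_model = spm.SentencePieceProcessor()
# SentencePieceProcessor.load() expects a string path, so convert explicitly.
bpe_model.load(str(lang_dir / "bpe.model"))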
@@ -338,8 +324,8 @@ def main():
         "whole-lattice-rescoring",
         "attention-decoder",
     ]:
-        logging.info(f"Loading HLG from {params.HLG}")
-        HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
+        logging.info(f"Loading HLG from {params.lang_dir}/HLG.pt")
+        HLG = k2.Fsa.from_dict(torch.load(params.lang_dir + "/HLG.pt", map_location="cpu"))
         HLG = HLG.to(device)
         if not hasattr(HLG, "lm_scores"):
             # For whole-lattice-rescoring and attention-decoder
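For readers following the HLG branch, a minimal sketch of the rewritten load, assuming HLG.pt under --lang-dir was saved with torch.save(HLG.as_dict(), ...) as icefall's lang-preparation scripts do (the lang_dir and device values here are hypothetical):

import k2
import torch

lang_dir = "data/lang_bpe_500"  # hypothetical value of params.lang_dir
device = torch.device("cpu")

HLG = k2.Fsa.from_dict(torch.load(lang_dir + "/HLG.pt", map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
    # Whole-lattice rescoring and the attention decoder need the LM scores
    # available separately, so keep a copy of the arc scores.
    HLG.lm_scores = HLG.scores.clone()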
@@ -404,7 +390,7 @@ def main():
         best_path = next(iter(best_path_dict.values()))

     hyps = get_texts(best_path)
-    word_sym_table = k2.SymbolTable.from_file(params.words_file)
+    word_sym_table = k2.SymbolTable.from_file(params.lang_dir + "/words.txt")
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]

     s = "\n"
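After these changes every lexical resource is resolved from --lang-dir alone. A minimal sketch of the final ID-to-word step, using a tiny hypothetical symbol table in place of words.txt:

import k2

# Hypothetical stand-in for lang_dir/words.txt; <eps> is conventionally ID 0.
word_sym_table = k2.SymbolTable.from_str("<eps> 0\nHELLO 1\nWORLD 2")

hyps = [[1, 2]]  # hypothetical word IDs as returned by get_texts(best_path)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
print(hyps)  # [['HELLO', 'WORLD']]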