Update pretrained.py

This commit is contained in:
Mingshuang Luo 2021-10-14 19:00:25 +08:00
parent 40db9985ec
commit a22b638820

View File

@ -20,7 +20,7 @@
import argparse import argparse
import logging import logging
import math import math
import os from pathlib import Path
from typing import List from typing import List
import k2 import k2
@ -298,8 +298,8 @@ def main():
if params.method == "ctc-decoding": if params.method == "ctc-decoding":
logging.info("Use CTC decoding") logging.info("Use CTC decoding")
if not os.path.exists(params.bpe_model): if not Path(params.bpe_model).exists():
raise ValueError("The path to bpe.model is required!") raise ValueError(f"The path to {params.bpe_model} doesn't exist!")
bpe_model = spm.SentencePieceProcessor() bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model) bpe_model.load(params.bpe_model)
@ -333,10 +333,10 @@ def main():
"whole-lattice-rescoring", "whole-lattice-rescoring",
"attention-decoder", "attention-decoder",
]: ]:
if not os.path.exists(params.HLG): if not Path(params.HLG).exists():
raise ValueError("The path to HLG.pt is required!") raise ValueError(f"The path to {params.HLG} doesn't exist!")
if not os.path.exists(params.words_file): if not Path(params.words_file).exists():
raise ValueError("The path to words.txt is required!") raise ValueError(f"The path to {params.words_file} doesn't exist!")
logging.info(f"Loading HLG from {params.HLG}") logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))