Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 18:24:18 +00:00
Update pretrained.py
parent 40db9985ec
commit a22b638820
@@ -20,7 +20,7 @@
 import argparse
 import logging
 import math
-import os
+from pathlib import Path
 from typing import List
 
 import k2
@@ -298,8 +298,8 @@ def main():
 
     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        if not os.path.exists(params.bpe_model):
-            raise ValueError("The path to bpe.model is required!")
+        if not Path(params.bpe_model).exists():
+            raise ValueError(f"The path to {params.bpe_model} doesn't exist!")
 
         bpe_model = spm.SentencePieceProcessor()
         bpe_model.load(params.bpe_model)
@@ -333,10 +333,10 @@ def main():
         "whole-lattice-rescoring",
         "attention-decoder",
     ]:
-        if not os.path.exists(params.HLG):
-            raise ValueError("The path to HLG.pt is required!")
-        if not os.path.exists(params.words_file):
-            raise ValueError("The path to words.txt is required!")
+        if not Path(params.HLG).exists():
+            raise ValueError(f"The path to {params.HLG} doesn't exist!")
+        if not Path(params.words_file).exists():
+            raise ValueError(f"The path to {params.words_file} doesn't exist!")
 
         logging.info(f"Loading HLG from {params.HLG}")
         HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
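
All three hunks apply the same pattern: the os.path.exists() check becomes pathlib.Path.exists(), and the generic "is required" message is replaced by one that names the path that was actually missing. A minimal sketch of that pattern, standalone rather than taken from pretrained.py (the helper name and example path below are illustrative, not part of the commit):

from pathlib import Path


def check_exists(path: str) -> None:
    # Raise an error that names the missing file, mirroring the
    # message format introduced by this commit.
    if not Path(path).exists():
        raise ValueError(f"The path to {path} doesn't exist!")


# Hypothetical usage; the path below is a placeholder, not a value
# taken from the commit.
try:
    check_exists("data/lang_bpe_500/bpe.model")
except ValueError as e:
    print(e)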