Support --tokens in conformer_ctc/export.py (replaces --lang-dir; the token table is now read directly from tokens.txt)

This commit is contained in:
Fangjun Kuang 2023-10-01 08:23:59 +08:00
parent 5a877da3a0
commit ad4af5c899

View File

@ -23,12 +23,12 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import torch import torch
from conformer import Conformer from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, num_tokens, str2bool
from icefall.utils import AttributeDict, str2bool
def get_parser(): def get_parser():
@ -63,11 +63,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_char", required=True,
help="""It contains language related input files such as "lexicon.txt" help="Path to the tokens.txt.",
""",
) )
parser.add_argument( parser.add_argument(
@ -98,16 +97,16 @@ def get_params() -> AttributeDict:
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
logging.info(params) # Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
lexicon = Lexicon(params.lang_dir) num_classes = num_tokens(token_table) + 1 # +1 for the blank
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank logging.info(params)
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():