diff --git a/egs/aishell/ASR/conformer_ctc/export.py b/egs/aishell/ASR/conformer_ctc/export.py index 1df3cfdc2..49871d437 100755 --- a/egs/aishell/ASR/conformer_ctc/export.py +++ b/egs/aishell/ASR/conformer_ctc/export.py @@ -23,12 +23,12 @@ import argparse import logging from pathlib import Path +import k2 import torch from conformer import Conformer from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -63,11 +63,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""It contains language related input files such as "lexicon.txt" - """, + required=True, + help="Path to the tokens.txt.", ) parser.add_argument( @@ -98,16 +97,16 @@ def get_params() -> AttributeDict: def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - logging.info(params) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank + num_classes = num_tokens(token_table) + 1 # +1 for the blank + + logging.info(params) device = torch.device("cpu") if torch.cuda.is_available():