diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py index 4117f7244..334dd011e 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/export.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py @@ -76,17 +76,17 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from train import add_model_arguments, get_params, get_transducer_model + from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) -from icefall.lexicon import UniqLexicon -from icefall.utils import str2bool -from train import add_model_arguments, get_params, get_transducer_model +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -143,13 +143,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -189,17 +186,9 @@ def main(): logging.info(f"device: {device}") - if "lang_bpe" in str(params.lang_dir): - sp = spm.SentencePieceProcessor() - sp.load(params.lang_dir + "/bpe.model") - # <blk> is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("<blk>") - params.vocab_size = sp.get_piece_size() - else: - assert "lang_phone" in str(params.lang_dir) - phone_lexicon = UniqLexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(phone_lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> logging.info(params) diff --git a/egs/librispeech/ASR/zipformer_ctc/export.py b/egs/librispeech/ASR/zipformer_ctc/export.py index 0ff50f128..4c46aea2c 100755 --- a/egs/librispeech/ASR/zipformer_ctc/export.py +++ 
b/egs/librispeech/ASR/zipformer_ctc/export.py @@ -23,6 +23,7 @@ import argparse import logging from pathlib import Path +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -33,8 +34,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -90,11 +90,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_bpe_500", - help="""It contains language related input files such as "lexicon.txt" - """, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -113,17 +112,15 @@ def get_parser(): def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) params = get_params() params.update(vars(args)) - logging.info(params) + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk> - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - params.vocab_size = num_classes + logging.info(params) device = torch.device("cpu") if torch.cuda.is_available():