diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py index bc33dd160..0f6190a41 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py @@ -23,7 +23,7 @@ Usage: ./pruned_transducer_stateless5/export.py \ --exp-dir ./pruned_transducer_stateless5/exp \ - --lang-dir ./data/lang_char \ + --tokens ./data/lang_char/tokens.txt \ --epoch 30 \ --avg 24 \ --use-averaged-model True @@ -50,8 +50,9 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -60,8 +61,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -118,13 +118,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -160,13 +157,14 @@ def main(): logging.info(f"device: {device}") - bpe_model = params.lang_dir + "/bpe.model" - sp = spm.SentencePieceProcessor() - sp.load(bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + # Load id of the token and the vocab size + # is defined in local/train_bpe_model.py + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -256,6 +254,7 @@ def main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file