Fix export for stateless5

Fangjun Kuang 2024-01-26 17:35:32 +08:00
parent ed68914fe2
commit d64cdf6d4f

Changed file: pruned_transducer_stateless5/export.py

@@ -20,7 +20,7 @@
 Usage for offline:
 ./pruned_transducer_stateless5/export.py \
   --exp-dir ./pruned_transducer_stateless5/exp_L_offline \
-  --lang-dir data/lang_char \
+  --tokens data/lang_char/tokens.txt \
   --epoch 4 \
   --avg 1
@@ -28,7 +28,7 @@ It will generate a file exp_dir/pretrained.pt for offline ASR.
 ./pruned_transducer_stateless5/export.py \
   --exp-dir ./pruned_transducer_stateless5/exp_L_offline \
-  --lang-dir data/lang_char \
+  --tokens data/lang_char/tokens.txt \
   --epoch 4 \
   --avg 1 \
   --jit True
@@ -38,7 +38,7 @@ It will generate a file exp_dir/cpu_jit.pt for offline ASR.
 Usage for streaming:
 ./pruned_transducer_stateless5/export.py \
   --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
-  --lang-dir data/lang_char \
+  --tokens data/lang_char/tokens.txt \
   --epoch 7 \
   --avg 1
@@ -46,7 +46,7 @@ It will generate a file exp_dir/pretrained.pt for streaming ASR.
 ./pruned_transducer_stateless5/export.py \
   --exp-dir ./pruned_transducer_stateless5/exp_L_streaming \
-  --lang-dir data/lang_char \
+  --tokens data/lang_char/tokens.txt \
   --epoch 7 \
   --avg 1 \
   --jit True
@@ -73,13 +73,13 @@ import argparse
 import logging
 from pathlib import Path
 
+import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.lexicon import Lexicon
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool
 
 
 def get_parser():
@@ -114,10 +114,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--lang-dir",
+        "--tokens",
         type=str,
-        default="data/lang_char",
-        help="The lang dir",
+        default="data/lang_char/tokens.txt",
+        help="Path to the tokens.txt",
     )
 
     parser.add_argument(
@@ -152,10 +152,9 @@ def main():
     logging.info(f"device: {device}")
 
-    lexicon = Lexicon(params.lang_dir)
-
-    params.blank_id = 0
-    params.vocab_size = max(lexicon.tokens) + 1
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1
 
     logging.info(params)
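
For reference, below is a minimal standalone sketch of what the new main() logic computes from tokens.txt alone, without a Lexicon. It assumes a tokens.txt with one "<symbol> <id>" pair per line; the inline tokens_txt content and the count_tokens helper are illustrative only, with count_tokens approximating icefall.utils.num_tokens (disambiguation symbols starting with "#" and the ID 0 are not counted), so vocab_size comes out the same as the old max(lexicon.tokens) + 1 here.

# Sketch: deriving blank_id and vocab_size from tokens.txt alone.
import k2

# Hypothetical tokens.txt content, one "<symbol> <id>" pair per line.
tokens_txt = """<blk> 0
<sos/eos> 1
你 2
好 3
#0 4
"""

# export.py reads a file instead: k2.SymbolTable.from_file(params.tokens)
token_table = k2.SymbolTable.from_str(tokens_txt)


def count_tokens(table: k2.SymbolTable) -> int:
    # Stand-in approximating icefall.utils.num_tokens: count real tokens,
    # skipping disambiguation symbols ("#...") and the blank ID 0.
    ids = [table[s] for s in table.symbols if not s.startswith("#")]
    return len([i for i in ids if i != 0])


blank_id = token_table["<blk>"]             # 0
vocab_size = count_tokens(token_table) + 1  # 3 + 1 = 4, i.e. max token ID + 1
print(blank_id, vocab_size)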