Update export-onnx.py

This commit is contained in:
jinzr 2024-01-27 02:09:16 +08:00
parent d9227665eb
commit c606ef5e50

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX 2. Export the model to ONNX
./pruned_transducer_stateless2/export-onnx.py \ ./pruned_transducer_stateless2/export-onnx.py \
--lang-dir $repo/data/lang_char \ --tokens $repo/data/lang_char/tokens.txt \
--epoch 99 \ --epoch 99 \
--avg 1 \ --avg 1 \
--exp-dir $repo/exp --exp-dir $repo/exp
@ -48,6 +48,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple from typing import Dict, Tuple
import k2
import onnx import onnx
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -57,14 +58,8 @@ from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import average_checkpoints, load_checkpoint
average_checkpoints, from icefall.utils import num_tokens, setup_logger, str2bool
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import setup_logger, str2bool
def get_parser(): def get_parser():
@ -110,10 +105,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_char", default="data/lang_char/tokens.txt",
help="The lang dir", help="Path to the tokens.txt",
) )
parser.add_argument( parser.add_argument(
@ -397,9 +392,9 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir) token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = 0 params.blank_id = token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1 params.vocab_size = num_tokens(token_table) + 1
logging.info(params) logging.info(params)