From c606ef5e5083b15171c4e78aa0c9f29f8137eaf5 Mon Sep 17 00:00:00 2001 From: jinzr Date: Sat, 27 Jan 2024 02:09:16 +0800 Subject: [PATCH] Update export-onnx.py --- .../export-onnx.py | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py index 140b1d37f..8aea79fe3 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export-onnx.py @@ -28,7 +28,7 @@ popd 2. Export the model to ONNX ./pruned_transducer_stateless2/export-onnx.py \ - --lang-dir $repo/data/lang_char \ + --tokens $repo/data/lang_char/tokens.txt \ --epoch 99 \ --avg 1 \ --exp-dir $repo/exp @@ -48,6 +48,7 @@ import logging from pathlib import Path from typing import Dict, Tuple +import k2 import onnx import torch import torch.nn as nn @@ -57,14 +58,8 @@ from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import get_params, get_transducer_model -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import num_tokens, setup_logger, str2bool def get_parser(): @@ -110,10 +105,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=str, - default="data/lang_char", - help="The lang dir", + default="data/lang_char/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -397,9 +392,9 @@ def main(): logging.info(f"device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + 
params.blank_id = token_table["<blk>"] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params)