mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
Update export-onnx.py
This commit is contained in:
parent
d9227665eb
commit
c606ef5e50
@ -28,7 +28,7 @@ popd
|
||||
2. Export the model to ONNX
|
||||
|
||||
./pruned_transducer_stateless2/export-onnx.py \
|
||||
--lang-dir $repo/data/lang_char \
|
||||
--tokens $repo/data/lang_char/tokens.txt \
|
||||
--epoch 99 \
|
||||
--avg 1 \
|
||||
--exp-dir $repo/exp
|
||||
@ -48,6 +48,7 @@ import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import k2
|
||||
import onnx
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -57,14 +58,8 @@ from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import setup_logger, str2bool
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -110,10 +105,10 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_char",
|
||||
help="The lang dir",
|
||||
default="data/lang_char/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -397,9 +392,9 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = token_table["<blk>"]
|
||||
params.vocab_size = num_tokens(token_table) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user