mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 02:06:13 +00:00
Zipformer Onnx fp16
This commit is contained in:
parent
b594a3875b
commit
5934b37e3f
@ -48,7 +48,8 @@ popd
|
||||
--joiner-dim 512 \
|
||||
--causal True \
|
||||
--chunk-size 16 \
|
||||
--left-context-frames 128
|
||||
--left-context-frames 128 \
|
||||
--fp16 True
|
||||
|
||||
The --chunk-size in training is "16,32,64,-1", so we select one of them
|
||||
(excluding -1) during streaming export. The same applies to `--left-context`,
|
||||
@ -74,6 +75,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from onnxconverter_common import float16
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from zipformer import Zipformer2
|
||||
@ -154,6 +156,13 @@ def get_parser():
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to export models in fp16",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -334,6 +343,7 @@ def export_encoder_model_onnx(
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
feature_dim: int = 80,
|
||||
fp16: bool = False,
|
||||
) -> None:
|
||||
encoder_model.encoder.__class__.forward = (
|
||||
encoder_model.encoder.__class__.streaming_forward
|
||||
@ -479,6 +489,11 @@ def export_encoder_model_onnx(
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
if(fp16) :
|
||||
logging.info("Exporting Encoder model in fp16")
|
||||
encoder = onnx.load(encoder_filename)
|
||||
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
|
||||
onnx.save(encoder_fp16,encoder_filename)
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
@ -726,6 +741,7 @@ def main():
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
feature_dim=params.feature_dim,
|
||||
fp16=params.fp16,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
|
@ -48,8 +48,8 @@ popd
|
||||
--joiner-dim 512 \
|
||||
--causal False \
|
||||
--chunk-size "16,32,64,-1" \
|
||||
--left-context-frames "64,128,256,-1"
|
||||
|
||||
--left-context-frames "64,128,256,-1" \
|
||||
--fp16 True
|
||||
It will generate the following 3 files inside $repo/exp:
|
||||
|
||||
- encoder-epoch-99-avg-1.onnx
|
||||
@ -71,6 +71,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from decoder import Decoder
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from onnxconverter_common import float16
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
from zipformer import Zipformer2
|
||||
@ -151,6 +152,13 @@ def get_parser():
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to export models in fp16",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -274,6 +282,7 @@ def export_encoder_model_onnx(
|
||||
encoder_model: OnnxEncoder,
|
||||
encoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
fp16:bool = False,
|
||||
) -> None:
|
||||
"""Export the given encoder model to ONNX format.
|
||||
The exported model has two inputs:
|
||||
@ -325,6 +334,12 @@ def export_encoder_model_onnx(
|
||||
|
||||
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||
|
||||
if(fp16) :
|
||||
logging.info("Exporting Encoder model in fp16")
|
||||
encoder = onnx.load(encoder_filename)
|
||||
encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
|
||||
onnx.save(encoder_fp16,encoder_filename)
|
||||
|
||||
|
||||
def export_decoder_model_onnx(
|
||||
decoder_model: OnnxDecoder,
|
||||
@ -563,6 +578,7 @@ def main():
|
||||
encoder,
|
||||
encoder_filename,
|
||||
opset_version=opset_version,
|
||||
fp16=params.fp16,
|
||||
)
|
||||
logging.info(f"Exported encoder to {encoder_filename}")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user