Zipformer Onnx FP16 (#1671)
Signed-off-by: manickavela29 <manickavela1998@gmail.com>
Commit eaab2c819f (parent b594a3875b)

The commit adds a --fp16 option to two Zipformer ONNX export scripts: when it is set, the exported encoder, decoder, and joiner are additionally converted to fp16 with onnxconverter_common and saved as *.fp16.onnx files, and the new dependency is added to the requirements file.
@@ -48,7 +48,8 @@ popd
   --joiner-dim 512 \
   --causal True \
   --chunk-size 16 \
-  --left-context-frames 128
+  --left-context-frames 128 \
+  --fp16 True

 The --chunk-size in training is "16,32,64,-1", so we select one of them
 (excluding -1) during streaming export. The same applies to `--left-context`,
@@ -73,6 +74,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from onnxconverter_common import float16
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -154,6 +156,13 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )

+    parser.add_argument(
+        "--fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to export models in fp16",
+    )
+
     add_model_arguments(parser)

     return parser
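The new flag is parsed with the str2bool helper that these export scripts already use (no new import appears in the diff). As an aside, a minimal sketch of what such a helper typically looks like, so that `--fp16 True` on the command line becomes a real boolean; this is an illustrative assumption, not the project's own implementation:

    import argparse

    def str2bool(v):
        # Hypothetical sketch of a str2bool argparse type: accepts common
        # true/false spellings and raises on anything else.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")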
@@ -479,7 +488,6 @@ def export_encoder_model_onnx(

     add_meta_data(filename=encoder_filename, meta_data=meta_data)

-
 def export_decoder_model_onnx(
     decoder_model: OnnxDecoder,
     decoder_filename: str,
@@ -747,6 +755,24 @@ def main():
     )
     logging.info(f"Exported joiner to {joiner_filename}")

+    if(params.fp16) :
+        logging.info("Generate fp16 models")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16,encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16,decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16,joiner_filename_fp16)
+
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

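The fp16 branch above only uses onnx.load, float16.convert_float_to_float16 with keep_io_types=True, and onnx.save. A minimal standalone sketch of the same conversion for a single exported model (the filename is illustrative):

    import onnx
    from onnxconverter_common import float16

    # Convert one exported ONNX model to fp16. keep_io_types=True keeps the
    # graph inputs/outputs in float32, so callers feed and read the same
    # tensor types as before; only the internal weights/ops become fp16.
    model = onnx.load("encoder-epoch-99-avg-1.onnx")  # illustrative name
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
    onnx.save(model_fp16, "encoder-epoch-99-avg-1.fp16.onnx")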
The same change is made to the second export script in the commit, whose usage example is non-streaming (--causal False):

@@ -48,8 +48,8 @@ popd
   --joiner-dim 512 \
   --causal False \
   --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
+  --left-context-frames "64,128,256,-1" \
+  --fp16 True
 It will generate the following 3 files inside $repo/exp:

 - encoder-epoch-99-avg-1.onnx
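With --fp16 True, the main() change further below additionally writes fp16 copies next to these files. A small sketch of how one might list what ends up in the export directory; the directory path and the exact fp16 filenames are assumptions based on the f"...-{suffix}.fp16.onnx" pattern in the diff:

    from pathlib import Path

    # List the exported models; with --fp16 True, each of encoder/decoder/joiner
    # should also have a *.fp16.onnx sibling (assumed naming, per the diff below).
    exp_dir = Path("zipformer/exp")  # illustrative path; the docstring uses $repo/exp
    for p in sorted(exp_dir.glob("*.onnx")):
        print(p.name)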
@@ -70,6 +70,7 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
+from onnxconverter_common import float16
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -151,6 +152,13 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )

+    parser.add_argument(
+        "--fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to export models in fp16",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -584,6 +592,24 @@ def main():
     )
     logging.info(f"Exported joiner to {joiner_filename}")

+    if(params.fp16) :
+        logging.info("Generate fp16 models")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16,encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16,decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16,joiner_filename_fp16)
+
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

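The exported *.fp16.onnx files are ordinary ONNX models, so they load with onnxruntime just like the fp32 ones; because keep_io_types=True was used, their inputs and outputs remain float32. A minimal sketch (filename illustrative) of loading one and inspecting its inputs:

    import onnxruntime as ort

    # Load an fp16 export and print its input signature. With keep_io_types=True
    # the inputs stay float32, so existing feature/feeding code keeps working.
    sess = ort.InferenceSession(
        "encoder-epoch-99-avg-1.fp16.onnx",  # illustrative filename
        providers=["CPUExecutionProvider"],
    )
    for inp in sess.get_inputs():
        print(inp.name, inp.type, inp.shape)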
Finally, the new dependency is added to the requirements file:

@@ -12,6 +12,7 @@ onnx>=1.15.0
 onnxruntime>=1.16.3
 onnxoptimizer
 onnxsim
+onnxconverter_common

 # style check session:
 black==22.3.0