From 5934b37e3fc0a300ad3d231573adb1b59d481270 Mon Sep 17 00:00:00 2001
From: manickavela29
Date: Wed, 26 Jun 2024 18:35:28 +0000
Subject: [PATCH] Zipformer ONNX fp16 export

---
 .../ASR/zipformer/export-onnx-streaming.py   | 18 ++++++++++++++++-
 egs/librispeech/ASR/zipformer/export-onnx.py | 20 +++++++++++++++++--
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 5d0c9ea43..bdecee726 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -48,7 +48,8 @@ popd
   --joiner-dim 512 \
   --causal True \
   --chunk-size 16 \
-  --left-context-frames 128
+  --left-context-frames 128 \
+  --fp16 True
 
 The --chunk-size in training is "16,32,64,-1", so we select one of them
 (excluding -1) during streaming export. The same applies to `--left-context`,
@@ -74,6 +75,7 @@ import torch
 import torch.nn as nn
 from decoder import Decoder
 from onnxruntime.quantization import QuantType, quantize_dynamic
+from onnxconverter_common import float16
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
 from zipformer import Zipformer2
@@ -154,6 +156,13 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to export models in fp16",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -334,6 +343,7 @@ def export_encoder_model_onnx(
     encoder_filename: str,
     opset_version: int = 11,
     feature_dim: int = 80,
+    fp16: bool = False,
 ) -> None:
     encoder_model.encoder.__class__.forward = (
         encoder_model.encoder.__class__.streaming_forward
@@ -479,6 +489,11 @@ def export_encoder_model_onnx(
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
 
+    if fp16:
+        logging.info("Exporting encoder model in fp16")
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        onnx.save(encoder_fp16, encoder_filename)
 
 def export_decoder_model_onnx(
     decoder_model: OnnxDecoder,
@@ -726,6 +741,7 @@ def main():
         encoder_filename,
         opset_version=opset_version,
         feature_dim=params.feature_dim,
+        fp16=params.fp16,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
 
diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index 3682f0b62..fe49da101 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -48,8 +48,8 @@ popd
   --joiner-dim 512 \
   --causal False \
   --chunk-size "16,32,64,-1" \
-  --left-context-frames "64,128,256,-1"
-
+  --left-context-frames "64,128,256,-1" \
+  --fp16 True
 It will generate the following 3 files inside $repo/exp:
 
  - encoder-epoch-99-avg-1.onnx
@@ -71,6 +71,7 @@ import torch
 import torch.nn as nn
 from decoder import Decoder
 from onnxruntime.quantization import QuantType, quantize_dynamic
+from onnxconverter_common import float16
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
 from zipformer import Zipformer2
@@ -151,6 +152,13 @@ def get_parser():
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to export models in fp16",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -274,6 +282,7 @@ def export_encoder_model_onnx(
     encoder_model: OnnxEncoder,
     encoder_filename: str,
     opset_version: int = 11,
+    fp16: bool = False,
 ) -> None:
     """Export the given encoder model to ONNX format.
     The exported model has two inputs:
@@ -325,6 +334,12 @@ def export_encoder_model_onnx(
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
 
+    if fp16:
+        logging.info("Exporting encoder model in fp16")
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        onnx.save(encoder_fp16, encoder_filename)
+
 
 def export_decoder_model_onnx(
     decoder_model: OnnxDecoder,
@@ -563,6 +578,7 @@ def main():
         encoder,
         encoder_filename,
         opset_version=opset_version,
+        fp16=params.fp16,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
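
Reviewer note (not part of the patch): for anyone verifying the conversion in isolation, here is a minimal standalone sketch of the fp16 step that both files now share, assuming `onnx` and `onnxconverter-common` are installed. The file names are hypothetical placeholders, not paths produced by this patch.

    # Minimal sketch of the fp16 conversion used in export_encoder_model_onnx().
    import onnx
    from onnxconverter_common import float16

    # Hypothetical path to an already-exported float32 encoder.
    model = onnx.load("encoder-epoch-99-avg-1.onnx")

    # keep_io_types=True keeps the graph inputs and outputs in float32, so
    # existing callers can keep feeding float32 feature tensors while the
    # weights and internal computation are stored in float16.
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)

    # The patch overwrites the float32 file in place; a separate output
    # name is used here only to keep the sketch non-destructive.
    onnx.save(model_fp16, "encoder-epoch-99-avg-1.fp16.onnx")

keep_io_types=True inserts cast nodes at the graph boundary, trading a small amount of overhead for drop-in compatibility with callers that still supply float32 inputs.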