diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index 6691c88b7..b32609122 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -343,7 +343,6 @@ def export_encoder_model_onnx(
     encoder_filename: str,
     opset_version: int = 11,
     feature_dim: int = 80,
-    fp16: bool = False,
 ) -> None:
     encoder_model.encoder.__class__.forward = (
         encoder_model.encoder.__class__.streaming_forward
@@ -489,12 +488,6 @@ def export_encoder_model_onnx(
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
 
-    if(fp16) :
-        logging.info("Exporting Encoder model in fp16")
-        encoder = onnx.load(encoder_filename)
-        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
-        onnx.save(encoder_fp16,encoder_filename)
-
 def export_decoder_model_onnx(
     decoder_model: OnnxDecoder,
     decoder_filename: str,
@@ -741,7 +734,6 @@ def main():
         encoder_filename,
         opset_version=opset_version,
         feature_dim=params.feature_dim,
-        fp16=params.fp16,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
 
@@ -766,8 +758,27 @@ def main():
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
 
-    logging.info("Generate int8 quantization models")
+    if(params.fp16) :
+        logging.info("Exporting models in fp16")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16,encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16,decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16,joiner_filename_fp16)
+
+    logging.info("Generate int8 quantization models")
+
     encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
     quantize_dynamic(
         model_input=encoder_filename,