mirror of https://github.com/k2-fsa/icefall.git

commit 683ae6c2cc (parent fa235adba2)

extending to export-onnx.py

Signed-off-by: manickavela29 <manickavela1998@gmail.com>
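The commit extends fp16 export support in icefall's ONNX export scripts. In the file patched at lines 755/777 below, fp16 conversion already existed and only the log message and the placement of the int8 comment change; in the file patched at lines 282-600, the per-model fp16 handling is removed from export_encoder_model_onnx() and reinstated in main(), where the encoder, decoder, and joiner are each converted after export. The fp16 path is gated on params.fp16, presumably a --fp16 command-line option added elsewhere in this commit (the argument definition is not shown in these hunks).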
@@ -755,12 +755,8 @@ def main():
     )
     logging.info(f"Exported joiner to {joiner_filename}")
 
-    # Generate int8 quantization models
-    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
-
-
     if(params.fp16) :
-        logging.info("Exporting models in fp16")
+        logging.info("Generate fp16 models")
 
         encoder = onnx.load(encoder_filename)
         encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
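For reference, the conversion above relies on the onnxconverter-common package. A minimal standalone sketch of the same call, with illustrative file names rather than the script's own paths:

# Sketch: fp16 conversion with onnxconverter-common (file names illustrative).
import onnx
from onnxconverter_common import float16

model = onnx.load("encoder.onnx")
# keep_io_types=True leaves the graph's inputs and outputs as float32, so
# callers keep feeding and fetching float32 tensors; only the internal
# weights and activations are cast to float16.
model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
onnx.save(model_fp16, "encoder.fp16.onnx")

A style nit on the committed line: if(params.fp16) : executes fine, but if params.fp16: is the idiomatic Python spelling.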
@@ -777,6 +773,9 @@ def main():
         joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
         onnx.save(joiner_fp16,joiner_filename_fp16)
 
+    # Generate int8 quantization models
+    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
+
     logging.info("Generate int8 quantization models")
 
     encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
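The int8 code the relocated comment introduces lies outside these hunks. For orientation only, dynamic quantization per the linked ONNX Runtime page typically looks like the sketch below; this is an assumption about the surrounding script, with illustrative file names.

# Sketch: post-training dynamic int8 quantization with onnxruntime.
from onnxruntime.quantization import QuantType, quantize_dynamic

quantize_dynamic(
    model_input="encoder.onnx",        # illustrative input path
    model_output="encoder.int8.onnx",  # illustrative output path
    weight_type=QuantType.QInt8,       # int8 weights, per the data-type-selection doc
)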
@@ -282,7 +282,6 @@ def export_encoder_model_onnx(
     encoder_model: OnnxEncoder,
     encoder_filename: str,
     opset_version: int = 11,
-    fp16:bool = False,
 ) -> None:
     """Export the given encoder model to ONNX format.
     The exported model has two inputs:
@@ -334,12 +333,6 @@ def export_encoder_model_onnx(
 
     add_meta_data(filename=encoder_filename, meta_data=meta_data)
 
-    if(fp16) :
-        logging.info("Exporting Encoder model in fp16")
-        encoder = onnx.load(encoder_filename)
-        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
-        onnx.save(encoder_fp16,encoder_filename)
-
 
 def export_decoder_model_onnx(
     decoder_model: OnnxDecoder,
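Note the behavioral fix in this removal: the deleted block converted only the encoder and saved the fp16 graph back over encoder_filename, overwriting the float32 export. The replacement in main() (hunk at line 600 below) converts all three models and writes them to separate *.fp16.onnx files, so the float32 exports are preserved alongside the fp16 ones.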
@@ -578,7 +571,6 @@ def main():
         encoder,
         encoder_filename,
         opset_version=opset_version,
-        fp16=params.fp16,
     )
     logging.info(f"Exported encoder to {encoder_filename}")
 
@@ -600,6 +592,24 @@ def main():
     )
     logging.info(f"Exported joiner to {joiner_filename}")
 
+    if(params.fp16) :
+        logging.info("Generate fp16 models")
+
+        encoder = onnx.load(encoder_filename)
+        encoder_fp16 = float16.convert_float_to_float16(encoder, keep_io_types=True)
+        encoder_filename_fp16 = params.exp_dir / f"encoder-{suffix}.fp16.onnx"
+        onnx.save(encoder_fp16,encoder_filename_fp16)
+
+        decoder = onnx.load(decoder_filename)
+        decoder_fp16 = float16.convert_float_to_float16(decoder, keep_io_types=True)
+        decoder_filename_fp16 = params.exp_dir / f"decoder-{suffix}.fp16.onnx"
+        onnx.save(decoder_fp16,decoder_filename_fp16)
+
+        joiner = onnx.load(joiner_filename)
+        joiner_fp16 = float16.convert_float_to_float16(joiner, keep_io_types=True)
+        joiner_filename_fp16 = params.exp_dir / f"joiner-{suffix}.fp16.onnx"
+        onnx.save(joiner_fp16,joiner_filename_fp16)
+
     # Generate int8 quantization models
     # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
 
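Because keep_io_types=True is used throughout, the fp16 models still take and return float32 tensors. A quick sanity check with onnxruntime, again with an illustrative path:

# Sketch: confirm an exported fp16 model still exposes float32 I/O.
import onnxruntime as ort

sess = ort.InferenceSession("encoder.fp16.onnx", providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print(inp.name, inp.type, inp.shape)  # expect tensor(float), not tensor(float16)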