Add export-onnx.py for fluent_speech_commands

Fangjun Kuang 2025-07-16 18:17:38 +08:00
parent e22bc78f98
commit 1e792e4fea


@@ -0,0 +1,339 @@
#!/usr/bin/env python3
#
# Copyright 2025 Xiaomi Corporation (Author: Fangjun Kuang)
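
"""
A usage sketch (the exact flags come from get_parser() in ./decode.py; the
values below are illustrative, not verified against this recipe):

    ./export-onnx.py \
      --epoch 30 \
      --avg 10 \
      --exp-dir ./exp

The script writes encoder.onnx, decoder.onnx, joiner.onnx, and their
dynamically int8-quantized variants into --exp-dir.
"""
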
import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import torch
from decode import get_params, get_parser, get_transducer_model
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import nn
from transducer.conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info

def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add metadata to an ONNX model. The model file is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
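
# A minimal sketch (a hypothetical helper, not part of the export pipeline)
# showing how the metadata written by add_meta_data() can be read back with
# onnxruntime:
def _print_meta_data(filename: str):
    import onnxruntime as ort

    session = ort.InferenceSession(filename, providers=["CPUExecutionProvider"])
    # custom_metadata_map holds the key-value pairs set by add_meta_data()
    print(session.get_modelmeta().custom_metadata_map)

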
class OnnxEncoder(nn.Module):
"""A wrapper for Conformer"""
def __init__(self, encoder: Conformer):
"""
Args:
encoder:
A Conformer encoder.
"""
super().__init__()
self.encoder = encoder
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Conformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
    - encoder_out, a 3-D tensor of shape (N, T', joiner_dim)
    - encoder_out_lens, a 1-D tensor of shape (N,)
"""
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
        # Now encoder_out is of shape (N, T', joiner_dim)
return encoder_out, encoder_out_lens


def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', C)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 23, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "conformer",
"version": "1",
"model_author": "k2-fsa",
"comment": "SLU_transducer",
"note": "The decoder is an LSTM with states",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
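
# A minimal sketch (a hypothetical helper, not called by this script) of
# running the exported encoder with onnxruntime; the 23-dim features match
# the dummy input used in the export above:
def _run_onnx_encoder(encoder_filename: str):
    import onnxruntime as ort

    session = ort.InferenceSession(
        encoder_filename, providers=["CPUExecutionProvider"]
    )
    x = torch.zeros(1, 100, 23, dtype=torch.float32)  # (N, T, C)
    x_lens = torch.tensor([100], dtype=torch.int64)  # (N,)
    encoder_out, encoder_out_lens = session.run(
        ["encoder_out", "encoder_out_lens"],
        {"x": x.numpy(), "x_lens": x_lens.numpy()},
    )
    print(encoder_out.shape, encoder_out_lens)

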
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has 3 inputs:
- y: a torch.int64 tensor of shape (N, 1)
- h: a float32 tensor of shape (num_layers, N, hidden_dim)
- c: a float32 tensor of shape (num_layers, N, hidden_dim)
    and has 3 outputs:
    - decoder_out: a torch.float32 tensor of shape (N, 1, C)
    - next_h: a float32 tensor of shape (num_layers, N, hidden_dim)
    - next_c: a float32 tensor of shape (num_layers, N, hidden_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(1, 1, dtype=torch.int64)
_, (h, c) = decoder_model(y)
torch.onnx.export(
decoder_model,
(y, (h, c)),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "h", "c"],
output_names=["decoder_out", "next_h", "next_c"],
dynamic_axes={
"y": {0: "N"},
"h": {1: "N"},
"c": {1: "N"},
},
)
meta_data = {
"num_layers": h.shape[0],
"hidden_dim": h.shape[2],
}
print("decoder meta_data", meta_data)
add_meta_data(filename=decoder_filename, meta_data=meta_data)
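
# A minimal sketch (a hypothetical helper, not called by this script) of
# stepping the exported LSTM decoder once and getting back its updated
# states; num_layers and hidden_dim are the values stored in the decoder's
# metadata above:
def _run_onnx_decoder(decoder_filename: str, num_layers: int, hidden_dim: int):
    import onnxruntime as ort

    session = ort.InferenceSession(
        decoder_filename, providers=["CPUExecutionProvider"]
    )
    y = torch.zeros(1, 1, dtype=torch.int64)
    h = torch.zeros(num_layers, 1, hidden_dim, dtype=torch.float32)
    c = torch.zeros(num_layers, 1, hidden_dim, dtype=torch.float32)
    decoder_out, next_h, next_c = session.run(
        ["decoder_out", "next_h", "next_c"],
        {"y": y.numpy(), "h": h.numpy(), "c": c.numpy()},
    )
    # next_h/next_c would be fed back in on the next step of greedy search.
    print(decoder_out.shape)

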
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
hidden_dim: int,
vocab_size: int,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has 2 inputs:
    - encoder_out: a float32 tensor of shape (N, 1, hidden_dim)
    - decoder_out: a float32 tensor of shape (N, 1, hidden_dim)
and has 1 output:
- logits: a torch.float32 tensor of shape (N, 1, vocab_size)
Args:
      joiner_model:
        The joiner model to be exported.
joiner_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
enc_out = torch.zeros(1, 1, hidden_dim, dtype=torch.float32)
dec_out = torch.zeros(1, 1, hidden_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(enc_out, dec_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out"],
output_names=["logits"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"vocab_size": vocab_size,
"hidden_dim": hidden_dim,
}
print("joiner meta_data", meta_data)
add_meta_data(filename=joiner_filename, meta_data=meta_data)
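
# A minimal sketch (a hypothetical helper, not called by this script) of one
# joiner step: combine a single encoder frame with the current decoder output
# to get per-token logits. In greedy search, logits.argmax(-1) picks the next
# token and non-blank tokens are fed back through the decoder.
def _run_onnx_joiner(joiner_filename: str, hidden_dim: int):
    import onnxruntime as ort

    session = ort.InferenceSession(
        joiner_filename, providers=["CPUExecutionProvider"]
    )
    encoder_out = torch.zeros(1, 1, hidden_dim, dtype=torch.float32)
    decoder_out = torch.zeros(1, 1, hidden_dim, dtype=torch.float32)
    (logits,) = session.run(
        ["logits"],
        {"encoder_out": encoder_out.numpy(), "decoder_out": decoder_out.numpy()},
    )
    print(logits.shape)  # (N, 1, vocab_size)

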
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params["env_info"] = get_env_info()
device = torch.device("cpu")
logging.info(f"device: {device}")
model = get_transducer_model(params)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
        if i >= 1:
            filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
model.device = device
encoder_num_param = sum([p.numel() for p in model.encoder.parameters()])
decoder_num_param = sum([p.numel() for p in model.decoder.parameters()])
joiner_num_param = sum([p.numel() for p in model.joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
encoder = OnnxEncoder(
encoder=model.encoder,
)
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / "encoder.onnx"
    export_encoder_model_onnx(
        encoder,
        encoder_filename,
        opset_version=opset_version,
    )
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / "decoder.onnx"
    export_decoder_model_onnx(
        model.decoder,
        decoder_filename,
        opset_version=opset_version,
    )
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
hidden_dim=params.hidden_dim,
vocab_size=params.vocab_size,
)
logging.info(f"Exported decoder to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / "encoder.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / "decoder.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul", "Gather"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / "joiner.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
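
    # A small convenience addition (not from the original export recipe): log
    # the size reduction achieved by dynamic int8 quantization, using only the
    # files written above.
    for name in ["encoder", "decoder", "joiner"]:
        f32 = params.exp_dir / f"{name}.onnx"
        int8 = params.exp_dir / f"{name}.int8.onnx"
        logging.info(
            f"{name}: {f32.stat().st_size / 1e6:.1f} MB "
            f"-> {int8.stat().st_size / 1e6:.1f} MB"
        )
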
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()