mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-09 14:05:33 +00:00
Support exporting BS Zipformer models to ONNX, used in Triton Server (#1008)
* Support export BS Zipformer models to ONNX in Tritron * Update copyright * Update exporting codes for BS zipformer models * Code format * Update comments * Update export_onnx.py --------- Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
This commit is contained in:
parent
05e7435d0d
commit
78b9dcc936
@ -2,6 +2,7 @@
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Yifan Yang)
|
||||
# 2023 NVIDIA Corporation (Author: Wen Ding)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -29,7 +30,8 @@ Usage:
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 13
|
||||
--avg 13 \
|
||||
--onnx 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
Check `onnx_check.py` for how to use them.
|
||||
@ -41,6 +43,25 @@ Check `onnx_check.py` for how to use them.
|
||||
- joiner_decoder_proj.onnx
|
||||
- lconv.onnx
|
||||
- frame_reducer.onnx
|
||||
- ctc_output.onnx
|
||||
|
||||
(2) Export to ONNX format which can be used in Triton Server
|
||||
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 13 \
|
||||
--onnx-triton 1
|
||||
|
||||
It will generate the following files in the given `exp_dir`.
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
- lconv.onnx
|
||||
- ctc_output.onnx
|
||||
|
||||
Please see ./onnx_pretrained.py for usage of the generated files
|
||||
|
||||
@ -78,6 +99,7 @@ from icefall.checkpoint import (
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import str2bool
|
||||
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -143,9 +165,10 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--onnx",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
default=False,
|
||||
help="""If True, --jit is ignored and it exports the model
|
||||
to onnx format. It will generate the following files:
|
||||
to onnx format.
|
||||
It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
@ -154,10 +177,28 @@ def get_parser():
|
||||
- joiner_decoder_proj.onnx
|
||||
- lconv.onnx
|
||||
- frame_reducer.onnx
|
||||
- ctc_output.onnx
|
||||
|
||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--onnx-triton",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If True, and it exports the model
|
||||
to onnx format which can be used in NVIDIA triton server.
|
||||
It will generate the following files:
|
||||
|
||||
- encoder.onnx
|
||||
- decoder.onnx
|
||||
- joiner.onnx
|
||||
- joiner_encoder_proj.onnx
|
||||
- joiner_decoder_proj.onnx
|
||||
- lconv.onnx
|
||||
- ctc_output.onnx
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
@ -273,6 +314,44 @@ def export_decoder_model_onnx(
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_decoder_model_onnx_triton(
|
||||
decoder_model: nn.Module,
|
||||
decoder_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the decoder model to ONNX-Triton format.
|
||||
The exported model has one input:
|
||||
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||
and has one output:
|
||||
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||
Args:
|
||||
decoder_model:
|
||||
The decoder model to be exported.
|
||||
decoder_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||
|
||||
decoder_model = TritonOnnxDecoder(decoder_model)
|
||||
|
||||
torch.onnx.export(
|
||||
decoder_model,
|
||||
(y),
|
||||
decoder_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["y"],
|
||||
output_names=["decoder_out"],
|
||||
dynamic_axes={
|
||||
"y": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
@ -369,6 +448,91 @@ def export_joiner_model_onnx(
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
def export_joiner_model_onnx_triton(
|
||||
joiner_model: nn.Module,
|
||||
joiner_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the joiner model to ONNX format.
|
||||
The exported joiner model has two inputs:
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
and produces one output:
|
||||
- logit: a tensor of shape (N, vocab_size)
|
||||
The exported encoder_proj model has one input:
|
||||
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||
and produces one output:
|
||||
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||
The exported decoder_proj model has one input:
|
||||
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||
and produces one output:
|
||||
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||
"""
|
||||
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
||||
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
||||
|
||||
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||
|
||||
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||
|
||||
# Note: It uses torch.jit.trace() internally
|
||||
joiner_model = TritonOnnxJoiner(joiner_model)
|
||||
torch.onnx.export(
|
||||
joiner_model,
|
||||
(projected_encoder_out, projected_decoder_out),
|
||||
joiner_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=[
|
||||
"encoder_out",
|
||||
"decoder_out",
|
||||
"project_input",
|
||||
],
|
||||
output_names=["logit"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"decoder_out": {0: "N"},
|
||||
"logit": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {joiner_filename}")
|
||||
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.encoder_proj,
|
||||
encoder_out,
|
||||
encoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["encoder_out"],
|
||||
output_names=["projected_encoder_out"],
|
||||
dynamic_axes={
|
||||
"encoder_out": {0: "N"},
|
||||
"projected_encoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {encoder_proj_filename}")
|
||||
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
joiner_model.decoder_proj,
|
||||
decoder_out,
|
||||
decoder_proj_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["decoder_out"],
|
||||
output_names=["projected_decoder_out"],
|
||||
dynamic_axes={
|
||||
"decoder_out": {0: "N"},
|
||||
"projected_decoder_out": {0: "N"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {decoder_proj_filename}")
|
||||
|
||||
|
||||
def export_lconv_onnx(
|
||||
lconv: nn.Module,
|
||||
lconv_filename: str,
|
||||
@ -413,6 +577,52 @@ def export_lconv_onnx(
|
||||
logging.info(f"Saved to {lconv_filename}")
|
||||
|
||||
|
||||
def export_lconv_onnx_triton(
|
||||
lconv: nn.Module,
|
||||
lconv_filename: str,
|
||||
opset_version: int = 11,
|
||||
) -> None:
|
||||
"""Export the lconv to ONNX format.
|
||||
|
||||
The exported lconv has two inputs:
|
||||
|
||||
- lconv_input: a tensor of shape (N, T, C)
|
||||
- lconv_input_lens: a tensor of shape (N, )
|
||||
|
||||
and has one output:
|
||||
|
||||
- lconv_out: a tensor of shape (N, T, C)
|
||||
|
||||
Args:
|
||||
lconv:
|
||||
The lconv to be exported.
|
||||
lconv_filename:
|
||||
Filename to save the exported ONNX model.
|
||||
opset_version:
|
||||
The opset version to use.
|
||||
"""
|
||||
lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
|
||||
lconv_input_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
||||
|
||||
lconv = TritonOnnxLconv(lconv)
|
||||
|
||||
torch.onnx.export(
|
||||
lconv,
|
||||
(lconv_input, lconv_input_lens),
|
||||
lconv_filename,
|
||||
verbose=False,
|
||||
opset_version=opset_version,
|
||||
input_names=["lconv_input", "lconv_input_lens"],
|
||||
output_names=["lconv_out"],
|
||||
dynamic_axes={
|
||||
"lconv_input": {0: "N", 1: "T"},
|
||||
"lconv_input_lens": {0: "N"},
|
||||
"lconv_out": {0: "N", 1: "T"},
|
||||
},
|
||||
)
|
||||
logging.info(f"Saved to {lconv_filename}")
|
||||
|
||||
|
||||
def export_frame_reducer_onnx(
|
||||
frame_reducer: nn.Module,
|
||||
frame_reducer_filename: str,
|
||||
@ -623,32 +833,54 @@ def main():
|
||||
)
|
||||
|
||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
if params.onnx is True:
|
||||
export_decoder_model_onnx(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
elif params.onnx_triton is True:
|
||||
export_decoder_model_onnx_triton(
|
||||
model.decoder,
|
||||
decoder_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
if params.onnx is True:
|
||||
export_joiner_model_onnx(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
elif params.onnx_triton is True:
|
||||
export_joiner_model_onnx_triton(
|
||||
model.joiner,
|
||||
joiner_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
lconv_filename = params.exp_dir / "lconv.onnx"
|
||||
export_lconv_onnx(
|
||||
model.lconv,
|
||||
lconv_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
if params.onnx is True:
|
||||
export_lconv_onnx(
|
||||
model.lconv,
|
||||
lconv_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
elif params.onnx_triton is True:
|
||||
export_lconv_onnx_triton(
|
||||
model.lconv,
|
||||
lconv_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
|
||||
export_frame_reducer_onnx(
|
||||
model.frame_reducer,
|
||||
frame_reducer_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
if params.onnx is True:
|
||||
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
|
||||
export_frame_reducer_onnx(
|
||||
model.frame_reducer,
|
||||
frame_reducer_filename,
|
||||
opset_version=opset_version,
|
||||
)
|
||||
|
||||
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
|
||||
export_ctc_output_onnx(
|
||||
|
||||
98
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
Executable file
98
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
Executable file
@ -0,0 +1,98 @@
|
||||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
class TritonOnnxDecoder(nn.Module):
|
||||
"""
|
||||
Triton wrapper for decoder model
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
"""
|
||||
Args:
|
||||
model: decoder model
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
|
||||
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
need_pad = False
|
||||
return self.model(y, need_pad)
|
||||
|
||||
|
||||
class TritonOnnxJoiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.encoder_proj = model.encoder_proj
|
||||
self.decoder_proj = model.decoder_proj
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, C).
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, C).
|
||||
"""
|
||||
project_input = False
|
||||
return self.model(encoder_out, decoder_out, project_input)
|
||||
|
||||
|
||||
class TritonOnnxLconv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lconv_input: torch.Tensor,
|
||||
lconv_input_lens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
lconv_input: Its shape is (N, T, C).
|
||||
lconv_input_lens: Its shape is (N, ).
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, C).
|
||||
"""
|
||||
mask = make_pad_mask(lconv_input_lens)
|
||||
|
||||
return self.model(x=lconv_input, src_key_padding_mask=mask)
|
||||
Loading…
x
Reference in New Issue
Block a user