Support exporting BS Zipformer models to ONNX, used in Triton Server (#1008)

* Support export BS Zipformer models to ONNX in Tritron

* Update copyright

* Update exporting codes for BS zipformer models

* Code format

* Update comments

* Update export_onnx.py

---------

Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
This commit is contained in:
Wen Ding 2023-04-18 17:05:08 +08:00 committed by GitHub
parent 05e7435d0d
commit 78b9dcc936
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 354 additions and 24 deletions

View File

@ -2,6 +2,7 @@
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
# Yifan Yang)
# 2023 NVIDIA Corporation (Author: Wen Ding)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -29,7 +30,8 @@ Usage:
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 13
--avg 13 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
@ -41,6 +43,25 @@ Check `onnx_check.py` for how to use them.
- joiner_decoder_proj.onnx
- lconv.onnx
- frame_reducer.onnx
- ctc_output.onnx
(2) Export to ONNX format which can be used in Triton Server
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 13 \
--onnx-triton 1
It will generate the following files in the given `exp_dir`.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- lconv.onnx
- ctc_output.onnx
Please see ./onnx_pretrained.py for usage of the generated files
@ -78,6 +99,7 @@ from icefall.checkpoint import (
load_checkpoint,
)
from icefall.utils import str2bool
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
def get_parser():
@ -143,9 +165,10 @@ def get_parser():
parser.add_argument(
"--onnx",
type=str2bool,
default=True,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
to onnx format.
It will generate the following files:
- encoder.onnx
- decoder.onnx
@ -154,10 +177,28 @@ def get_parser():
- joiner_decoder_proj.onnx
- lconv.onnx
- frame_reducer.onnx
- ctc_output.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--onnx-triton",
type=str2bool,
default=False,
help="""If True, and it exports the model
to onnx format which can be used in NVIDIA triton server.
It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
- lconv.onnx
- ctc_output.onnx
""",
)
parser.add_argument(
"--context-size",
@ -273,6 +314,44 @@ def export_decoder_model_onnx(
logging.info(f"Saved to {decoder_filename}")
def export_decoder_model_onnx_triton(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX-Triton format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
decoder_model = TritonOnnxDecoder(decoder_model)
torch.onnx.export(
decoder_model,
(y),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
@ -369,6 +448,91 @@ def export_joiner_model_onnx(
logging.info(f"Saved to {decoder_proj_filename}")
def export_joiner_model_onnx_triton(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
# Note: It uses torch.jit.trace() internally
joiner_model = TritonOnnxJoiner(joiner_model)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
def export_lconv_onnx(
lconv: nn.Module,
lconv_filename: str,
@ -413,6 +577,52 @@ def export_lconv_onnx(
logging.info(f"Saved to {lconv_filename}")
def export_lconv_onnx_triton(
lconv: nn.Module,
lconv_filename: str,
opset_version: int = 11,
) -> None:
"""Export the lconv to ONNX format.
The exported lconv has two inputs:
- lconv_input: a tensor of shape (N, T, C)
- lconv_input_lens: a tensor of shape (N, )
and has one output:
- lconv_out: a tensor of shape (N, T, C)
Args:
lconv:
The lconv to be exported.
lconv_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
lconv_input_lens = torch.tensor([498] * 15, dtype=torch.int64)
lconv = TritonOnnxLconv(lconv)
torch.onnx.export(
lconv,
(lconv_input, lconv_input_lens),
lconv_filename,
verbose=False,
opset_version=opset_version,
input_names=["lconv_input", "lconv_input_lens"],
output_names=["lconv_out"],
dynamic_axes={
"lconv_input": {0: "N", 1: "T"},
"lconv_input_lens": {0: "N"},
"lconv_out": {0: "N", 1: "T"},
},
)
logging.info(f"Saved to {lconv_filename}")
def export_frame_reducer_onnx(
frame_reducer: nn.Module,
frame_reducer_filename: str,
@ -623,32 +833,54 @@ def main():
)
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
if params.onnx is True:
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
elif params.onnx_triton is True:
export_decoder_model_onnx_triton(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
if params.onnx is True:
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
elif params.onnx_triton is True:
export_joiner_model_onnx_triton(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
lconv_filename = params.exp_dir / "lconv.onnx"
export_lconv_onnx(
model.lconv,
lconv_filename,
opset_version=opset_version,
)
if params.onnx is True:
export_lconv_onnx(
model.lconv,
lconv_filename,
opset_version=opset_version,
)
elif params.onnx_triton is True:
export_lconv_onnx_triton(
model.lconv,
lconv_filename,
opset_version=opset_version,
)
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
export_frame_reducer_onnx(
model.frame_reducer,
frame_reducer_filename,
opset_version=opset_version,
)
if params.onnx is True:
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
export_frame_reducer_onnx(
model.frame_reducer,
frame_reducer_filename,
opset_version=opset_version,
)
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
export_ctc_output_onnx(

View File

@ -0,0 +1,98 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from icefall.utils import make_pad_mask
class TritonOnnxDecoder(nn.Module):
"""
Triton wrapper for decoder model
"""
def __init__(self, model):
"""
Args:
model: decoder model
"""
super().__init__()
self.model = model
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
need_pad = False
return self.model(y, need_pad)
class TritonOnnxJoiner(nn.Module):
def __init__(
self,
model,
):
super().__init__()
self.model = model
self.encoder_proj = model.encoder_proj
self.decoder_proj = model.decoder_proj
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, T, C).
Returns:
Return a tensor of shape (N, T, C).
"""
project_input = False
return self.model(encoder_out, decoder_out, project_input)
class TritonOnnxLconv(nn.Module):
def __init__(
self,
model,
):
super().__init__()
self.model = model
def forward(
self,
lconv_input: torch.Tensor,
lconv_input_lens: torch.Tensor,
) -> torch.Tensor:
"""
Args:
lconv_input: Its shape is (N, T, C).
lconv_input_lens: Its shape is (N, ).
Returns:
Return a tensor of shape (N, T, C).
"""
mask = make_pad_mask(lconv_input_lens)
return self.model(x=lconv_input, src_key_padding_mask=mask)