mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-10 22:45:27 +00:00
Support exporting BS Zipformer models to ONNX, used in Triton Server (#1008)
* Support export BS Zipformer models to ONNX in Tritron * Update copyright * Update exporting codes for BS zipformer models * Code format * Update comments * Update export_onnx.py --------- Co-authored-by: Yifan Yang <64255737+yfyeung@users.noreply.github.com>
This commit is contained in:
parent
05e7435d0d
commit
78b9dcc936
@ -2,6 +2,7 @@
|
|||||||
#
|
#
|
||||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
# Yifan Yang)
|
# Yifan Yang)
|
||||||
|
# 2023 NVIDIA Corporation (Author: Wen Ding)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -29,7 +30,8 @@ Usage:
|
|||||||
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 13
|
--avg 13 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
It will generate the following files in the given `exp_dir`.
|
It will generate the following files in the given `exp_dir`.
|
||||||
Check `onnx_check.py` for how to use them.
|
Check `onnx_check.py` for how to use them.
|
||||||
@ -41,6 +43,25 @@ Check `onnx_check.py` for how to use them.
|
|||||||
- joiner_decoder_proj.onnx
|
- joiner_decoder_proj.onnx
|
||||||
- lconv.onnx
|
- lconv.onnx
|
||||||
- frame_reducer.onnx
|
- frame_reducer.onnx
|
||||||
|
- ctc_output.onnx
|
||||||
|
|
||||||
|
(2) Export to ONNX format which can be used in Triton Server
|
||||||
|
./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 13 \
|
||||||
|
--onnx-triton 1
|
||||||
|
|
||||||
|
It will generate the following files in the given `exp_dir`.
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
- joiner_encoder_proj.onnx
|
||||||
|
- joiner_decoder_proj.onnx
|
||||||
|
- lconv.onnx
|
||||||
|
- ctc_output.onnx
|
||||||
|
|
||||||
Please see ./onnx_pretrained.py for usage of the generated files
|
Please see ./onnx_pretrained.py for usage of the generated files
|
||||||
|
|
||||||
@ -78,6 +99,7 @@ from icefall.checkpoint import (
|
|||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -143,9 +165,10 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--onnx",
|
"--onnx",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=False,
|
||||||
help="""If True, --jit is ignored and it exports the model
|
help="""If True, --jit is ignored and it exports the model
|
||||||
to onnx format. It will generate the following files:
|
to onnx format.
|
||||||
|
It will generate the following files:
|
||||||
|
|
||||||
- encoder.onnx
|
- encoder.onnx
|
||||||
- decoder.onnx
|
- decoder.onnx
|
||||||
@ -154,10 +177,28 @@ def get_parser():
|
|||||||
- joiner_decoder_proj.onnx
|
- joiner_decoder_proj.onnx
|
||||||
- lconv.onnx
|
- lconv.onnx
|
||||||
- frame_reducer.onnx
|
- frame_reducer.onnx
|
||||||
|
- ctc_output.onnx
|
||||||
|
|
||||||
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-triton",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""If True, and it exports the model
|
||||||
|
to onnx format which can be used in NVIDIA triton server.
|
||||||
|
It will generate the following files:
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
- joiner_encoder_proj.onnx
|
||||||
|
- joiner_decoder_proj.onnx
|
||||||
|
- lconv.onnx
|
||||||
|
- ctc_output.onnx
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
@ -273,6 +314,44 @@ def export_decoder_model_onnx(
|
|||||||
logging.info(f"Saved to {decoder_filename}")
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx_triton(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX-Triton format.
|
||||||
|
The exported model has one input:
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
and has one output:
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
|
||||||
|
decoder_model = TritonOnnxDecoder(decoder_model)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
(y),
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_joiner_model_onnx(
|
def export_joiner_model_onnx(
|
||||||
joiner_model: nn.Module,
|
joiner_model: nn.Module,
|
||||||
joiner_filename: str,
|
joiner_filename: str,
|
||||||
@ -369,6 +448,91 @@ def export_joiner_model_onnx(
|
|||||||
logging.info(f"Saved to {decoder_proj_filename}")
|
logging.info(f"Saved to {decoder_proj_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx_triton(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
and produces one output:
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
The exported encoder_proj model has one input:
|
||||||
|
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||||
|
and produces one output:
|
||||||
|
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
The exported decoder_proj model has one input:
|
||||||
|
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||||
|
and produces one output:
|
||||||
|
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
||||||
|
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
||||||
|
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Note: It uses torch.jit.trace() internally
|
||||||
|
joiner_model = TritonOnnxJoiner(joiner_model)
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
"project_input",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model.encoder_proj,
|
||||||
|
encoder_out,
|
||||||
|
encoder_proj_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["encoder_out"],
|
||||||
|
output_names=["projected_encoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"projected_encoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {encoder_proj_filename}")
|
||||||
|
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model.decoder_proj,
|
||||||
|
decoder_out,
|
||||||
|
decoder_proj_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["decoder_out"],
|
||||||
|
output_names=["projected_decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"projected_decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {decoder_proj_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_lconv_onnx(
|
def export_lconv_onnx(
|
||||||
lconv: nn.Module,
|
lconv: nn.Module,
|
||||||
lconv_filename: str,
|
lconv_filename: str,
|
||||||
@ -413,6 +577,52 @@ def export_lconv_onnx(
|
|||||||
logging.info(f"Saved to {lconv_filename}")
|
logging.info(f"Saved to {lconv_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_lconv_onnx_triton(
|
||||||
|
lconv: nn.Module,
|
||||||
|
lconv_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the lconv to ONNX format.
|
||||||
|
|
||||||
|
The exported lconv has two inputs:
|
||||||
|
|
||||||
|
- lconv_input: a tensor of shape (N, T, C)
|
||||||
|
- lconv_input_lens: a tensor of shape (N, )
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- lconv_out: a tensor of shape (N, T, C)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lconv:
|
||||||
|
The lconv to be exported.
|
||||||
|
lconv_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
|
||||||
|
lconv_input_lens = torch.tensor([498] * 15, dtype=torch.int64)
|
||||||
|
|
||||||
|
lconv = TritonOnnxLconv(lconv)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
lconv,
|
||||||
|
(lconv_input, lconv_input_lens),
|
||||||
|
lconv_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["lconv_input", "lconv_input_lens"],
|
||||||
|
output_names=["lconv_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"lconv_input": {0: "N", 1: "T"},
|
||||||
|
"lconv_input_lens": {0: "N"},
|
||||||
|
"lconv_out": {0: "N", 1: "T"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {lconv_filename}")
|
||||||
|
|
||||||
|
|
||||||
def export_frame_reducer_onnx(
|
def export_frame_reducer_onnx(
|
||||||
frame_reducer: nn.Module,
|
frame_reducer: nn.Module,
|
||||||
frame_reducer_filename: str,
|
frame_reducer_filename: str,
|
||||||
@ -623,32 +833,54 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
decoder_filename = params.exp_dir / "decoder.onnx"
|
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||||
export_decoder_model_onnx(
|
if params.onnx is True:
|
||||||
model.decoder,
|
export_decoder_model_onnx(
|
||||||
decoder_filename,
|
model.decoder,
|
||||||
opset_version=opset_version,
|
decoder_filename,
|
||||||
)
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
elif params.onnx_triton is True:
|
||||||
|
export_decoder_model_onnx_triton(
|
||||||
|
model.decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
joiner_filename = params.exp_dir / "joiner.onnx"
|
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||||
export_joiner_model_onnx(
|
if params.onnx is True:
|
||||||
model.joiner,
|
export_joiner_model_onnx(
|
||||||
joiner_filename,
|
model.joiner,
|
||||||
opset_version=opset_version,
|
joiner_filename,
|
||||||
)
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
elif params.onnx_triton is True:
|
||||||
|
export_joiner_model_onnx_triton(
|
||||||
|
model.joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
lconv_filename = params.exp_dir / "lconv.onnx"
|
lconv_filename = params.exp_dir / "lconv.onnx"
|
||||||
export_lconv_onnx(
|
if params.onnx is True:
|
||||||
model.lconv,
|
export_lconv_onnx(
|
||||||
lconv_filename,
|
model.lconv,
|
||||||
opset_version=opset_version,
|
lconv_filename,
|
||||||
)
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
elif params.onnx_triton is True:
|
||||||
|
export_lconv_onnx_triton(
|
||||||
|
model.lconv,
|
||||||
|
lconv_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
|
if params.onnx is True:
|
||||||
export_frame_reducer_onnx(
|
frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
|
||||||
model.frame_reducer,
|
export_frame_reducer_onnx(
|
||||||
frame_reducer_filename,
|
model.frame_reducer,
|
||||||
opset_version=opset_version,
|
frame_reducer_filename,
|
||||||
)
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
|
ctc_output_filename = params.exp_dir / "ctc_output.onnx"
|
||||||
export_ctc_output_onnx(
|
export_ctc_output_onnx(
|
||||||
|
|||||||
98
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
Executable file
98
egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
Executable file
@ -0,0 +1,98 @@
|
|||||||
|
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
|
class TritonOnnxDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Triton wrapper for decoder model
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model: decoder model
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, U).
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, U, decoder_dim).
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
return self.model(y, need_pad)
|
||||||
|
|
||||||
|
|
||||||
|
class TritonOnnxJoiner(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.encoder_proj = model.encoder_proj
|
||||||
|
self.decoder_proj = model.decoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, C).
|
||||||
|
decoder_out:
|
||||||
|
Output from the decoder. Its shape is (N, T, C).
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, T, C).
|
||||||
|
"""
|
||||||
|
project_input = False
|
||||||
|
return self.model(encoder_out, decoder_out, project_input)
|
||||||
|
|
||||||
|
|
||||||
|
class TritonOnnxLconv(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
lconv_input: torch.Tensor,
|
||||||
|
lconv_input_lens: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lconv_input: Its shape is (N, T, C).
|
||||||
|
lconv_input_lens: Its shape is (N, ).
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, T, C).
|
||||||
|
"""
|
||||||
|
mask = make_pad_mask(lconv_input_lens)
|
||||||
|
|
||||||
|
return self.model(x=lconv_input, src_key_padding_mask=mask)
|
||||||
Loading…
x
Reference in New Issue
Block a user