diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
index 50efa6e60..630a7f735 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/export_onnx.py
@@ -2,6 +2,7 @@
 #
 # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang,
 #                                            Yifan Yang)
+#           2023 NVIDIA Corporation (Author: Wen Ding)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -29,7 +30,8 @@ Usage:
   --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
   --epoch 30 \
-  --avg 13
+  --avg 13 \
+  --onnx 1

 It will generate the following files in the given `exp_dir`.
 Check `onnx_check.py` for how to use them.
@@ -41,6 +43,25 @@ Check `onnx_check.py` for how to use them.
   - joiner_decoder_proj.onnx
   - lconv.onnx
   - frame_reducer.onnx
+  - ctc_output.onnx
+
+(2) Export to ONNX format for use in NVIDIA Triton server
+./pruned_transducer_stateless7_ctc_bs/export_onnx.py \
+  --exp-dir ./pruned_transducer_stateless7_ctc_bs/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 30 \
+  --avg 13 \
+  --onnx-triton 1
+
+It will generate the following files in the given `exp_dir`.
+
+  - encoder.onnx
+  - decoder.onnx
+  - joiner.onnx
+  - joiner_encoder_proj.onnx
+  - joiner_decoder_proj.onnx
+  - lconv.onnx
+  - ctc_output.onnx

 Please see ./onnx_pretrained.py for usage of the generated files
@@ -78,6 +99,7 @@ from icefall.checkpoint import (
     load_checkpoint,
 )
 from icefall.utils import str2bool
+from onnx_wrapper import TritonOnnxDecoder, TritonOnnxJoiner, TritonOnnxLconv


 def get_parser():
@@ -143,9 +165,10 @@ def get_parser():
     parser.add_argument(
         "--onnx",
         type=str2bool,
-        default=True,
+        default=False,
         help="""If True, --jit is ignored and it exports the model
-        to onnx format. It will generate the following files:
+        to onnx format.
+        It will generate the following files:

         - encoder.onnx
         - decoder.onnx
@@ -154,10 +177,28 @@ def get_parser():
         - joiner_decoder_proj.onnx
         - lconv.onnx
         - frame_reducer.onnx
+        - ctc_output.onnx

         Refer to ./onnx_check.py and ./onnx_pretrained.py
         for how to use them.
         """,
     )

+    parser.add_argument(
+        "--onnx-triton",
+        type=str2bool,
+        default=False,
+        help="""If True, it exports the model to onnx format,
+        which can be used in NVIDIA Triton server.
+        It will generate the following files:
+
+        - encoder.onnx
+        - decoder.onnx
+        - joiner.onnx
+        - joiner_encoder_proj.onnx
+        - joiner_decoder_proj.onnx
+        - lconv.onnx
+        - ctc_output.onnx
+        """,
+    )

     parser.add_argument(
         "--context-size",
@@ -273,6 +314,44 @@ def export_decoder_model_onnx(
     logging.info(f"Saved to {decoder_filename}")


+def export_decoder_model_onnx_triton(
+    decoder_model: nn.Module,
+    decoder_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the decoder model to ONNX format for Triton.
+
+    The exported model has one input:
+
+      - y: a torch.int64 tensor of shape (N, decoder_model.context_size)
+
+    and has one output:
+
+      - decoder_out: a torch.float32 tensor of shape (N, 1, C)
+
+    Args:
+      decoder_model:
+        The decoder model to be exported.
+      decoder_filename:
+        Filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+
+    decoder_model = TritonOnnxDecoder(decoder_model)
+
+    torch.onnx.export(
+        decoder_model,
+        (y,),
+        decoder_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["y"],
+        output_names=["decoder_out"],
+        dynamic_axes={
+            "y": {0: "N"},
+            "decoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {decoder_filename}")
+
+
 def export_joiner_model_onnx(
     joiner_model: nn.Module,
     joiner_filename: str,
@@ -369,6 +448,91 @@ def export_joiner_model_onnx(
     logging.info(f"Saved to {decoder_proj_filename}")


+def export_joiner_model_onnx_triton(
+    joiner_model: nn.Module,
+    joiner_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the joiner model to ONNX format for Triton.
+
+    The exported joiner model has two inputs, both of which must
+    already be projected to joiner_dim:
+
+      - encoder_out: the projected encoder output, of shape (N, joiner_dim)
+      - decoder_out: the projected decoder output, of shape (N, joiner_dim)
+
+    and produces one output:
+
+      - logit: a tensor of shape (N, vocab_size)
+
+    The exported encoder_proj model has one input:
+
+      - encoder_out: a tensor of shape (N, encoder_out_dim)
+
+    and produces one output:
+
+      - projected_encoder_out: a tensor of shape (N, joiner_dim)
+
+    The exported decoder_proj model has one input:
+
+      - decoder_out: a tensor of shape (N, decoder_out_dim)
+
+    and produces one output:
+
+      - projected_decoder_out: a tensor of shape (N, joiner_dim)
+    """
+    encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
+    decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
+
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    joiner_dim = joiner_model.decoder_proj.weight.shape[0]
+
+    projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+    projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32)
+
+    joiner_model = TritonOnnxJoiner(joiner_model)
+    # Note: torch.onnx.export() uses torch.jit.trace() internally.
+    # The wrapper's forward() takes only two inputs, so only two input
+    # names are listed (project_input is fixed to False in the wrapper).
+    torch.onnx.export(
+        joiner_model,
+        (projected_encoder_out, projected_decoder_out),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=[
+            "encoder_out",
+            "decoder_out",
+        ],
+        output_names=["logit"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {joiner_filename}")
+
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    torch.onnx.export(
+        joiner_model.encoder_proj,
+        encoder_out,
+        encoder_proj_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out"],
+        output_names=["projected_encoder_out"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "projected_encoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {encoder_proj_filename}")
+
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+    torch.onnx.export(
+        joiner_model.decoder_proj,
+        decoder_out,
+        decoder_proj_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["decoder_out"],
+        output_names=["projected_decoder_out"],
+        dynamic_axes={
+            "decoder_out": {0: "N"},
+            "projected_decoder_out": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {decoder_proj_filename}")
+
+
 def export_lconv_onnx(
     lconv: nn.Module,
     lconv_filename: str,
@@ -413,6 +577,52 @@ def export_lconv_onnx(
     logging.info(f"Saved to {lconv_filename}")


+def export_lconv_onnx_triton(
+    lconv: nn.Module,
+    lconv_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the lconv module to ONNX format for Triton.
+
+    The exported lconv has two inputs:
+
+      - lconv_input: a tensor of shape (N, T, C)
+      - lconv_input_lens: a tensor of shape (N,)
+
+    and has one output:
+
+      - lconv_out: a tensor of shape (N, T, C)
+
+    Args:
+      lconv:
+        The lconv module to be exported.
+      lconv_filename:
+        Filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    lconv_input = torch.zeros(15, 498, 384, dtype=torch.float32)
+    lconv_input_lens = torch.tensor([498] * 15, dtype=torch.int64)
+
+    lconv = TritonOnnxLconv(lconv)
+
+    torch.onnx.export(
+        lconv,
+        (lconv_input, lconv_input_lens),
+        lconv_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["lconv_input", "lconv_input_lens"],
+        output_names=["lconv_out"],
+        dynamic_axes={
+            "lconv_input": {0: "N", 1: "T"},
+            "lconv_input_lens": {0: "N"},
+            "lconv_out": {0: "N", 1: "T"},
+        },
+    )
+    logging.info(f"Saved to {lconv_filename}")
+
+
 def export_frame_reducer_onnx(
     frame_reducer: nn.Module,
     frame_reducer_filename: str,
@@ -623,32 +833,54 @@ def main():
     )

     decoder_filename = params.exp_dir / "decoder.onnx"
-    export_decoder_model_onnx(
-        model.decoder,
-        decoder_filename,
-        opset_version=opset_version,
-    )
+    if params.onnx is True:
+        export_decoder_model_onnx(
+            model.decoder,
+            decoder_filename,
+            opset_version=opset_version,
+        )
+    elif params.onnx_triton is True:
+        export_decoder_model_onnx_triton(
+            model.decoder,
+            decoder_filename,
+            opset_version=opset_version,
+        )

     joiner_filename = params.exp_dir / "joiner.onnx"
-    export_joiner_model_onnx(
-        model.joiner,
-        joiner_filename,
-        opset_version=opset_version,
-    )
+    if params.onnx is True:
+        export_joiner_model_onnx(
+            model.joiner,
+            joiner_filename,
+            opset_version=opset_version,
+        )
+    elif params.onnx_triton is True:
+        export_joiner_model_onnx_triton(
+            model.joiner,
+            joiner_filename,
+            opset_version=opset_version,
+        )

     lconv_filename = params.exp_dir / "lconv.onnx"
-    export_lconv_onnx(
-        model.lconv,
-        lconv_filename,
-        opset_version=opset_version,
-    )
+    if params.onnx is True:
+        export_lconv_onnx(
+            model.lconv,
+            lconv_filename,
+            opset_version=opset_version,
+        )
+    elif params.onnx_triton is True:
+        export_lconv_onnx_triton(
+            model.lconv,
+            lconv_filename,
+            opset_version=opset_version,
+        )

-    frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
-    export_frame_reducer_onnx(
-        model.frame_reducer,
-        frame_reducer_filename,
-        opset_version=opset_version,
-    )
+    if params.onnx is True:
+        frame_reducer_filename = params.exp_dir / "frame_reducer.onnx"
+        export_frame_reducer_onnx(
+            model.frame_reducer,
+            frame_reducer_filename,
+            opset_version=opset_version,
+        )

     ctc_output_filename = params.exp_dir / "ctc_output.onnx"
     export_ctc_output_onnx(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
new file mode 100755
index 000000000..247da0949
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc_bs/onnx_wrapper.py
@@ -0,0 +1,98 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+
+from icefall.utils import make_pad_mask
+
+
+class TritonOnnxDecoder(nn.Module):
+    """Triton wrapper for the decoder model."""
+
+    def __init__(self, model):
+        """
+        Args:
+          model: the decoder model to be wrapped
+        """
+        super().__init__()
+
+        self.model = model
+
+    def forward(self, y: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U).
+        Returns:
+          Return a tensor of shape (N, U - context_size + 1, decoder_dim).
+          For the ONNX export U equals context_size, so the output
+          has shape (N, 1, decoder_dim).
+        """
+        # Disable padding so that the traced graph matches the
+        # (N, context_size) input used at inference time.
+        need_pad = False
+        return self.model(y, need_pad)
+
+
+class TritonOnnxJoiner(nn.Module):
+    """Triton wrapper for the joiner model."""
+
+    def __init__(
+        self,
+        model,
+    ):
+        super().__init__()
+
+        self.model = model
+        self.encoder_proj = model.encoder_proj
+        self.decoder_proj = model.decoder_proj
+
+    def forward(
+        self,
+        encoder_out: torch.Tensor,
+        decoder_out: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          encoder_out:
+            Projected encoder output. Its shape is (N, joiner_dim).
+          decoder_out:
+            Projected decoder output. Its shape is (N, joiner_dim).
+        Returns:
+          Return a tensor of shape (N, vocab_size).
+        """
+        # The inputs are already projected, so skip the internal
+        # encoder/decoder projections.
+        project_input = False
+        return self.model(encoder_out, decoder_out, project_input)
+
+
+class TritonOnnxLconv(nn.Module):
+    """Triton wrapper for the lconv module."""
+
+    def __init__(
+        self,
+        model,
+    ):
+        super().__init__()
+
+        self.model = model
+
+    def forward(
+        self,
+        lconv_input: torch.Tensor,
+        lconv_input_lens: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          lconv_input:
+            Its shape is (N, T, C).
+          lconv_input_lens:
+            Its shape is (N,).
+        Returns:
+          Return a tensor of shape (N, T, C).
+        """
+        mask = make_pad_mask(lconv_input_lens)
+
+        return self.model(x=lconv_input, src_key_padding_mask=mask)
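
For a quick smoke test of the exported files before wiring them into Triton, the models can be fed dummy inputs through onnxruntime. The following is a minimal sketch, not part of this PR: it assumes onnxruntime is installed, that the models were exported with --onnx-triton 1 into the exp dir below, and that the lconv channel dimension is 384, matching the dummy input used by export_lconv_onnx_triton above; the batch sizes and sequence lengths are arbitrary.

import numpy as np
import onnxruntime as ort

exp_dir = "./pruned_transducer_stateless7_ctc_bs/exp"  # assumed location
providers = ["CPUExecutionProvider"]

# Decoder: input "y" of shape (N, context_size), output "decoder_out".
decoder = ort.InferenceSession(f"{exp_dir}/decoder.onnx", providers=providers)
context_size = decoder.get_inputs()[0].shape[1]  # static axis
y = np.zeros((4, context_size), dtype=np.int64)
(decoder_out,) = decoder.run(["decoder_out"], {"y": y})
print("decoder_out:", decoder_out.shape)  # expected (4, 1, decoder_dim)

# Projections: map raw encoder/decoder outputs to joiner_dim.
enc_proj = ort.InferenceSession(f"{exp_dir}/joiner_encoder_proj.onnx", providers=providers)
dec_proj = ort.InferenceSession(f"{exp_dir}/joiner_decoder_proj.onnx", providers=providers)
enc_dim = enc_proj.get_inputs()[0].shape[1]
dec_dim = dec_proj.get_inputs()[0].shape[1]
(proj_enc,) = enc_proj.run(
    None, {"encoder_out": np.random.rand(4, enc_dim).astype(np.float32)}
)
(proj_dec,) = dec_proj.run(
    None, {"decoder_out": np.random.rand(4, dec_dim).astype(np.float32)}
)

# Joiner: both inputs must already be projected, since project_input
# is fixed to False inside TritonOnnxJoiner.
joiner = ort.InferenceSession(f"{exp_dir}/joiner.onnx", providers=providers)
(logit,) = joiner.run(
    ["logit"], {"encoder_out": proj_enc, "decoder_out": proj_dec}
)
print("logit:", logit.shape)  # expected (4, vocab_size)

# Lconv: "lconv_input" (N, T, C) plus "lconv_input_lens" (N,).
lconv = ort.InferenceSession(f"{exp_dir}/lconv.onnx", providers=providers)
x = np.random.rand(2, 100, 384).astype(np.float32)
x_lens = np.array([100, 80], dtype=np.int64)
(lconv_out,) = lconv.run(
    ["lconv_out"], {"lconv_input": x, "lconv_input_lens": x_lens}
)
print("lconv_out:", lconv_out.shape)  # expected (2, 100, 384)

See ./onnx_pretrained.py for the full decoding pipeline that consumes the generated files.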