mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Add streaming onnx export for zipformer (#831)
* add streaming onnx export for zipformer * update triton support * add comments * add ci test * add onnxmltools for fp16 onnx export
This commit is contained in:
parent
029c8566e4
commit
bf5f0342a2
@ -33,6 +33,16 @@ ln -s pretrained.pt epoch-99.pt
|
|||||||
ls -lh *.pt
|
ls -lh *.pt
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
log "Test exporting to ONNX format"
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
|
--exp-dir $repo/exp \
|
||||||
|
--use-averaged-model false \
|
||||||
|
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 99 \
|
||||||
|
--avg 1 \
|
||||||
|
--fp16 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
log "Export to torchscript model"
|
log "Export to torchscript model"
|
||||||
./pruned_transducer_stateless7_streaming/export.py \
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
|
@ -39,7 +39,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
run_librispeech_2022_12_29_zipformer_streaming:
|
run_librispeech_2022_12_29_zipformer_streaming:
|
||||||
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
|
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
@ -72,25 +72,81 @@ Check ./pretrained.py for its usage.
|
|||||||
Note: If you don't want to train a model from scratch, we have
|
Note: If you don't want to train a model from scratch, we have
|
||||||
provided one for you. You can get it at
|
provided one for you. You can get it at
|
||||||
|
|
||||||
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
|
|
||||||
with the following commands:
|
with the following commands:
|
||||||
|
|
||||||
sudo apt-get install git-lfs
|
sudo apt-get install git-lfs
|
||||||
git lfs install
|
git lfs install
|
||||||
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
|
||||||
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
|
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp
|
||||||
|
|
||||||
|
(3) Export to ONNX format with pretrained.pt
|
||||||
|
|
||||||
|
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
|
ln -s pretrained.pt epoch-999.pt
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
|
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--fp16 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
It will generate the following files in the given `exp_dir`.
|
||||||
|
Check `onnx_check.py` for how to use them.
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
- joiner_encoder_proj.onnx
|
||||||
|
- joiner_decoder_proj.onnx
|
||||||
|
|
||||||
|
Check
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx
|
||||||
|
for how to use the exported models outside of icefall.
|
||||||
|
|
||||||
|
(4) Export to ONNX format for triton server
|
||||||
|
|
||||||
|
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
|
||||||
|
ln -s pretrained.pt epoch-999.pt
|
||||||
|
./pruned_transducer_stateless7_streaming/export.py \
|
||||||
|
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--use-averaged-model False \
|
||||||
|
--epoch 999 \
|
||||||
|
--avg 1 \
|
||||||
|
--fp16 \
|
||||||
|
--onnx-triton 1 \
|
||||||
|
--onnx 1
|
||||||
|
|
||||||
|
It will generate the following files in the given `exp_dir`.
|
||||||
|
Check `onnx_check.py` for how to use them.
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
|
||||||
|
Check
|
||||||
|
https://github.com/k2-fsa/sherpa/tree/master/triton
|
||||||
|
for how to use the exported models outside of icefall.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import onnxruntime
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
from zipformer import stack_states
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -172,6 +228,42 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""If True, --jit is ignored and it exports the model
|
||||||
|
to onnx format. It will generate the following files:
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
- joiner_encoder_proj.onnx
|
||||||
|
- joiner_decoder_proj.onnx
|
||||||
|
|
||||||
|
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--onnx-triton",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""If True, --onnx would export model into the following files:
|
||||||
|
|
||||||
|
- encoder.onnx
|
||||||
|
- decoder.onnx
|
||||||
|
- joiner.onnx
|
||||||
|
These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--fp16",
|
||||||
|
action="store_true",
|
||||||
|
help="whether to export fp16 onnx model, default false",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--context-size",
|
"--context-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -184,6 +276,391 @@ def get_parser():
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
|
||||||
|
for a, b in zip(xlist, blist):
|
||||||
|
try:
|
||||||
|
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
|
||||||
|
except AssertionError as error:
|
||||||
|
if tolerate_small_mismatch:
|
||||||
|
print("small mismatch detected", error)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: nn.Module,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T, C)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
|
|
||||||
|
Note: The warmup argument is fixed to 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
batch_size = 17
|
||||||
|
seq_len = 101
|
||||||
|
torch.manual_seed(0)
|
||||||
|
x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)
|
||||||
|
|
||||||
|
# encoder_model = torch.jit.script(encoder_model)
|
||||||
|
# It throws the following error for the above statement
|
||||||
|
#
|
||||||
|
# RuntimeError: Exporting the operator __is_ to ONNX opset version
|
||||||
|
# 11 is not supported. Please feel free to request support or
|
||||||
|
# submit a pull request on PyTorch GitHub.
|
||||||
|
#
|
||||||
|
# I cannot find which statement causes the above error.
|
||||||
|
# torch.onnx.export() will use torch.jit.trace() internally, which
|
||||||
|
# works well for the current reworked model
|
||||||
|
initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
|
||||||
|
states = stack_states(initial_states)
|
||||||
|
|
||||||
|
left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
|
||||||
|
encoder_attention_dim = encoder_model.encoders[0].attention_dim
|
||||||
|
|
||||||
|
len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15
|
||||||
|
avg_cache = torch.cat(
|
||||||
|
states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
|
||||||
|
).transpose(
|
||||||
|
0, 1
|
||||||
|
) # [B,15,384]
|
||||||
|
cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
|
||||||
|
0, 1
|
||||||
|
) # [B,2*15,384,cnn_kernel-1]
|
||||||
|
pad_tensors = [
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
tensor,
|
||||||
|
(
|
||||||
|
0,
|
||||||
|
encoder_attention_dim - tensor.shape[-1],
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
left_context_len - tensor.shape[1],
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tensor in states[
|
||||||
|
2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
|
||||||
|
]
|
||||||
|
]
|
||||||
|
attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
|
||||||
|
|
||||||
|
encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model_wrapper,
|
||||||
|
(x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"x",
|
||||||
|
"x_lens",
|
||||||
|
"len_cache",
|
||||||
|
"avg_cache",
|
||||||
|
"attn_cache",
|
||||||
|
"cnn_cache",
|
||||||
|
],
|
||||||
|
output_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"encoder_out_lens",
|
||||||
|
"new_len_cache",
|
||||||
|
"new_avg_cache",
|
||||||
|
"new_attn_cache",
|
||||||
|
"new_cnn_cache",
|
||||||
|
],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"x_lens": {0: "N"},
|
||||||
|
"encoder_out": {0: "N", 1: "T"},
|
||||||
|
"encoder_out_lens": {0: "N"},
|
||||||
|
"len_cache": {0: "N"},
|
||||||
|
"avg_cache": {0: "N"},
|
||||||
|
"attn_cache": {0: "N"},
|
||||||
|
"cnn_cache": {0: "N"},
|
||||||
|
"new_len_cache": {0: "N"},
|
||||||
|
"new_avg_cache": {0: "N"},
|
||||||
|
"new_attn_cache": {0: "N"},
|
||||||
|
"new_cnn_cache": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {encoder_filename}")
|
||||||
|
|
||||||
|
# Test onnx encoder with torch native encoder
|
||||||
|
encoder_model.eval()
|
||||||
|
(
|
||||||
|
encoder_out_torch,
|
||||||
|
encoder_out_lens_torch,
|
||||||
|
new_states_torch,
|
||||||
|
) = encoder_model.streaming_forward(
|
||||||
|
x=x,
|
||||||
|
x_lens=x_lens,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
ort_session = onnxruntime.InferenceSession(
|
||||||
|
str(encoder_filename), providers=["CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
ort_inputs = {
|
||||||
|
"x": x.numpy(),
|
||||||
|
"x_lens": x_lens.numpy(),
|
||||||
|
"len_cache": len_cache.numpy(),
|
||||||
|
"avg_cache": avg_cache.numpy(),
|
||||||
|
"attn_cache": attn_cache.numpy(),
|
||||||
|
"cnn_cache": cnn_cache.numpy(),
|
||||||
|
}
|
||||||
|
ort_outs = ort_session.run(None, ort_inputs)
|
||||||
|
|
||||||
|
assert test_acc(
|
||||||
|
[encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
|
||||||
|
)
|
||||||
|
logging.info(f"{encoder_filename} acc test succeeded.")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
need_pad = False # Always False, so we can use torch.jit.trace() here
|
||||||
|
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
|
||||||
|
# in this case
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
(y, need_pad),
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y", "need_pad"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_decoder_model_onnx_triton(
|
||||||
|
decoder_model: nn.Module,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
|
||||||
|
|
||||||
|
Note: The argument need_pad is fixed to False.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
|
||||||
|
|
||||||
|
decoder_model = TritonOnnxDecoder(decoder_model)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
(y,),
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {decoder_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
|
||||||
|
The exported encoder_proj model has one input:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- projected_encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
The exported decoder_proj model has one input:
|
||||||
|
|
||||||
|
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- projected_decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
|
||||||
|
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
|
||||||
|
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
project_input = False
|
||||||
|
# Note: It uses torch.jit.trace() internally
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out, project_input),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
"project_input",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model.encoder_proj,
|
||||||
|
encoder_out,
|
||||||
|
encoder_proj_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["encoder_out"],
|
||||||
|
output_names=["projected_encoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"projected_encoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {encoder_proj_filename}")
|
||||||
|
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model.decoder_proj,
|
||||||
|
decoder_out,
|
||||||
|
decoder_proj_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["decoder_out"],
|
||||||
|
output_names=["projected_decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"projected_decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {decoder_proj_filename}")
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx_triton(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
- encoder_out: a tensor of shape (N, encoder_out_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, decoder_out_dim)
|
||||||
|
and has one output:
|
||||||
|
- joiner_out: a tensor of shape (N, vocab_size)
|
||||||
|
Note: The argument project_input is fixed to True. A user should not
|
||||||
|
project the encoder_out/decoder_out by himself/herself. The exported joiner
|
||||||
|
will do that for the user.
|
||||||
|
"""
|
||||||
|
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
|
||||||
|
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
|
||||||
|
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||||
|
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
joiner_model = TritonOnnxJoiner(joiner_model)
|
||||||
|
# Note: It uses torch.jit.trace() internally
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(encoder_out, decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["encoder_out", "decoder_out"],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
logging.info(f"Saved to {joiner_filename}")
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
@ -292,7 +769,87 @@ def main():
|
|||||||
model.to("cpu")
|
model.to("cpu")
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit is True:
|
if params.onnx:
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
opset_version = 13
|
||||||
|
logging.info("Exporting to onnx format")
|
||||||
|
encoder_filename = params.exp_dir / "encoder.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
model.encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
if not params.onnx_triton:
|
||||||
|
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
model.decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
model.joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
decoder_filename = params.exp_dir / "decoder.onnx"
|
||||||
|
export_decoder_model_onnx_triton(
|
||||||
|
model.decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename = params.exp_dir / "joiner.onnx"
|
||||||
|
export_joiner_model_onnx_triton(
|
||||||
|
model.joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.fp16:
|
||||||
|
try:
|
||||||
|
import onnxmltools
|
||||||
|
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
||||||
|
except ImportError:
|
||||||
|
print("Please install onnxmltools!")
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
|
||||||
|
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
|
||||||
|
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
|
||||||
|
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
||||||
|
|
||||||
|
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
|
||||||
|
export_onnx_fp16(encoder_filename, encoder_fp16_filename)
|
||||||
|
|
||||||
|
decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
|
||||||
|
export_onnx_fp16(decoder_filename, decoder_fp16_filename)
|
||||||
|
|
||||||
|
joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
|
||||||
|
export_onnx_fp16(joiner_filename, joiner_fp16_filename)
|
||||||
|
|
||||||
|
if not params.onnx_triton:
|
||||||
|
encoder_proj_filename = str(joiner_filename).replace(
|
||||||
|
".onnx", "_encoder_proj.onnx"
|
||||||
|
)
|
||||||
|
encoder_proj_fp16_filename = (
|
||||||
|
params.exp_dir / "joiner_encoder_proj_fp16.onnx"
|
||||||
|
)
|
||||||
|
export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)
|
||||||
|
|
||||||
|
decoder_proj_filename = str(joiner_filename).replace(
|
||||||
|
".onnx", "_decoder_proj.onnx"
|
||||||
|
)
|
||||||
|
decoder_proj_fp16_filename = (
|
||||||
|
params.exp_dir / "joiner_decoder_proj_fp16.onnx"
|
||||||
|
)
|
||||||
|
export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)
|
||||||
|
|
||||||
|
elif params.jit:
|
||||||
convert_scaled_to_non_scaled(model, inplace=True)
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
# We won't use the forward() method of the model in C++, so just ignore
|
||||||
# it here.
|
# it here.
|
||||||
|
@ -0,0 +1,231 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxStreamingEncoder(torch.nn.Module):
|
||||||
|
"""This class warps the streaming Zipformer to reduce the number of
|
||||||
|
state tensors for onnx.
|
||||||
|
https://github.com/k2-fsa/icefall/pull/831
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, encoder):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder: A Instance of Zipformer Class
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = encoder
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
len_cache: torch.tensor,
|
||||||
|
avg_cache: torch.tensor,
|
||||||
|
attn_cache: torch.tensor,
|
||||||
|
cnn_cache: torch.tensor,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||||
|
x_lens:
|
||||||
|
A tensor of shape (batch_size,) containing the number of frames in
|
||||||
|
`x` before padding.
|
||||||
|
len_cache:
|
||||||
|
The cached numbers of past frames.
|
||||||
|
avg_cache:
|
||||||
|
The cached average tensors.
|
||||||
|
attn_cache:
|
||||||
|
The cached key tensors of the first attention modules.
|
||||||
|
The cached value tensors of the first attention modules.
|
||||||
|
The cached value tensors of the second attention modules.
|
||||||
|
cnn_cache:
|
||||||
|
The cached left contexts of the first convolution modules.
|
||||||
|
The cached left contexts of the second convolution modules.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing 2 tensors:
|
||||||
|
|
||||||
|
"""
|
||||||
|
num_encoder_layers = []
|
||||||
|
encoder_attention_dims = []
|
||||||
|
states = []
|
||||||
|
for i, encoder in enumerate(self.model.encoders):
|
||||||
|
num_encoder_layers.append(encoder.num_layers)
|
||||||
|
encoder_attention_dims.append(encoder.attention_dim)
|
||||||
|
|
||||||
|
len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B]
|
||||||
|
offset = 0
|
||||||
|
for num_layer in num_encoder_layers:
|
||||||
|
states.append(len_cache[offset : offset + num_layer])
|
||||||
|
offset += num_layer
|
||||||
|
|
||||||
|
avg_cache = avg_cache.transpose(0, 1) # [15, B, 384]
|
||||||
|
offset = 0
|
||||||
|
for num_layer in num_encoder_layers:
|
||||||
|
states.append(avg_cache[offset : offset + num_layer])
|
||||||
|
offset += num_layer
|
||||||
|
|
||||||
|
attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192]
|
||||||
|
left_context_len = attn_cache.shape[1]
|
||||||
|
offset = 0
|
||||||
|
for i, num_layer in enumerate(num_encoder_layers):
|
||||||
|
ds = self.model.zipformer_downsampling_factors[i]
|
||||||
|
states.append(
|
||||||
|
attn_cache[offset : offset + num_layer, : left_context_len // ds]
|
||||||
|
)
|
||||||
|
offset += num_layer
|
||||||
|
for i, num_layer in enumerate(num_encoder_layers):
|
||||||
|
encoder_attention_dim = encoder_attention_dims[i]
|
||||||
|
ds = self.model.zipformer_downsampling_factors[i]
|
||||||
|
states.append(
|
||||||
|
attn_cache[
|
||||||
|
offset : offset + num_layer,
|
||||||
|
: left_context_len // ds,
|
||||||
|
:,
|
||||||
|
: encoder_attention_dim // 2,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
offset += num_layer
|
||||||
|
for i, num_layer in enumerate(num_encoder_layers):
|
||||||
|
ds = self.model.zipformer_downsampling_factors[i]
|
||||||
|
states.append(
|
||||||
|
attn_cache[
|
||||||
|
offset : offset + num_layer,
|
||||||
|
: left_context_len // ds,
|
||||||
|
:,
|
||||||
|
: encoder_attention_dim // 2,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
offset += num_layer
|
||||||
|
|
||||||
|
cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1]
|
||||||
|
offset = 0
|
||||||
|
for num_layer in num_encoder_layers:
|
||||||
|
states.append(cnn_cache[offset : offset + num_layer])
|
||||||
|
offset += num_layer
|
||||||
|
for num_layer in num_encoder_layers:
|
||||||
|
states.append(cnn_cache[offset : offset + num_layer])
|
||||||
|
offset += num_layer
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
|
||||||
|
x=x,
|
||||||
|
x_lens=x_lens,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose(
|
||||||
|
0, 1
|
||||||
|
) # [B,15]
|
||||||
|
new_avg_cache = torch.cat(
|
||||||
|
states[self.model.num_encoders : 2 * self.model.num_encoders]
|
||||||
|
).transpose(
|
||||||
|
0, 1
|
||||||
|
) # [B,15,384]
|
||||||
|
new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose(
|
||||||
|
0, 1
|
||||||
|
) # [B,2*15,384,cnn_kernel-1]
|
||||||
|
assert len(set(encoder_attention_dims)) == 1
|
||||||
|
pad_tensors = [
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
tensor,
|
||||||
|
(
|
||||||
|
0,
|
||||||
|
encoder_attention_dims[0] - tensor.shape[-1],
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
left_context_len - tensor.shape[1],
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for tensor in states[
|
||||||
|
2 * self.model.num_encoders : 5 * self.model.num_encoders
|
||||||
|
]
|
||||||
|
]
|
||||||
|
new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
|
||||||
|
|
||||||
|
return (
|
||||||
|
encoder_out,
|
||||||
|
encoder_out_lens,
|
||||||
|
new_len_cache,
|
||||||
|
new_avg_cache,
|
||||||
|
new_attn_cache,
|
||||||
|
new_cnn_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TritonOnnxDecoder(torch.nn.Module):
|
||||||
|
"""This class warps the Decoder in decoder.py
|
||||||
|
to remove the scalar input "need_pad".
|
||||||
|
Triton currently doesn't support scalar input.
|
||||||
|
https://github.com/triton-inference-server/server/issues/2333
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
decoder: torch.nn.Module,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
decoder: A instance of Decoder
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.model = decoder
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, U).
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, U, decoder_dim).
|
||||||
|
"""
|
||||||
|
# False to not pad the input. Should be False during inference.
|
||||||
|
need_pad = False
|
||||||
|
return self.model(y, need_pad)
|
||||||
|
|
||||||
|
|
||||||
|
class TritonOnnxJoiner(torch.nn.Module):
|
||||||
|
"""This class warps the Joiner in joiner.py
|
||||||
|
to remove the scalar input "project_input".
|
||||||
|
Triton currently doesn't support scalar input.
|
||||||
|
https://github.com/triton-inference-server/server/issues/2333
|
||||||
|
"project_input" is set to True.
|
||||||
|
Triton solutions only need export joiner to a single joiner.onnx.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
joiner: torch.nn.Module,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model = joiner
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||||
|
decoder_out:
|
||||||
|
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||||
|
Returns:
|
||||||
|
Return a tensor of shape (N, T, s_range, C).
|
||||||
|
"""
|
||||||
|
# Apply input projections encoder_proj and decoder_proj.
|
||||||
|
project_input = True
|
||||||
|
return self.model(encoder_out, decoder_out, project_input)
|
@ -2084,6 +2084,16 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
# the following .as_strided() expression converts the last axis of pos_weights from relative
|
# the following .as_strided() expression converts the last axis of pos_weights from relative
|
||||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||||
# not, but let this code define which way round it is supposed to be.
|
# not, but let this code define which way round it is supposed to be.
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
(batch_size, num_heads, time1, n) = pos_weights.shape
|
||||||
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
|
cols = torch.arange(seq_len)
|
||||||
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
|
indexes = rows + cols
|
||||||
|
pos_weights = pos_weights.reshape(-1, n)
|
||||||
|
pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
|
||||||
|
pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len)
|
||||||
|
else:
|
||||||
pos_weights = pos_weights.as_strided(
|
pos_weights = pos_weights.as_strided(
|
||||||
(bsz, num_heads, seq_len, seq_len),
|
(bsz, num_heads, seq_len, seq_len),
|
||||||
(
|
(
|
||||||
@ -2275,6 +2285,16 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
# the following .as_strided() expression converts the last axis of pos_weights from relative
|
# the following .as_strided() expression converts the last axis of pos_weights from relative
|
||||||
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||||
# not, but let this code define which way round it is supposed to be.
|
# not, but let this code define which way round it is supposed to be.
|
||||||
|
if torch.jit.is_tracing():
|
||||||
|
(batch_size, num_heads, time1, n) = pos_weights.shape
|
||||||
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
|
cols = torch.arange(kv_len)
|
||||||
|
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
|
||||||
|
indexes = rows + cols
|
||||||
|
pos_weights = pos_weights.reshape(-1, n)
|
||||||
|
pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
|
||||||
|
pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len)
|
||||||
|
else:
|
||||||
pos_weights = pos_weights.as_strided(
|
pos_weights = pos_weights.as_strided(
|
||||||
(bsz, num_heads, seq_len, kv_len),
|
(bsz, num_heads, seq_len, kv_len),
|
||||||
(
|
(
|
||||||
|
@ -22,5 +22,6 @@ typeguard==2.13.3
|
|||||||
multi_quantization
|
multi_quantization
|
||||||
|
|
||||||
onnx
|
onnx
|
||||||
|
onnxmltools
|
||||||
onnxruntime
|
onnxruntime
|
||||||
kaldifst
|
kaldifst
|
||||||
|
Loading…
x
Reference in New Issue
Block a user