Add streaming onnx export for zipformer (#831)

* add streaming onnx export for zipformer

* update triton support

* add comments

* add ci test

* add onnxmltools for fp16 onnx export
Yuekai Zhang 2023-02-06 10:37:07 +08:00 committed by GitHub
parent 029c8566e4
commit bf5f0342a2
6 changed files with 843 additions and 24 deletions


@@ -33,6 +33,16 @@ ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--fp16 \
--onnx 1
log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \


@@ -39,7 +39,7 @@ concurrency:
jobs:
run_librispeech_2022_12_29_zipformer_streaming:
-if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
+if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:

pruned_transducer_stateless7_streaming/export.py

@@ -72,25 +72,81 @@ Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
-https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
with the following commands:
sudo apt-get install git-lfs
git lfs install
-git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
+git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
(3) Export to ONNX format with pretrained.pt
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
--fp16 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(4) Export to ONNX format for Triton server
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
--fp16 \
--onnx-triton 1 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
Check
https://github.com/k2-fsa/sherpa/tree/master/triton
for how to use the exported models outside of icefall.
"""
import argparse
import logging
from pathlib import Path
import onnxruntime
import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states
from icefall.checkpoint import (
average_checkpoints,
@@ -172,6 +228,42 @@ def get_parser():
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and it exports the model
to onnx format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--onnx-triton",
type=str2bool,
default=False,
help="""If True, --onnx would export model into the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
These files are used by https://github.com/k2-fsa/sherpa/tree/master/triton.
""",
)
parser.add_argument(
"--fp16",
action="store_true",
help="whether to export fp16 onnx model, default false",
)
parser.add_argument(
"--context-size",
type=int,
@@ -184,6 +276,391 @@ def get_parser():
return parser
def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
"""Return True if each pair of tensors from (xlist, blist) matches within
rtol/atol; with tolerate_small_mismatch, mismatches are printed but tolerated."""
for a, b in zip(xlist, blist):
try:
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
except AssertionError as error:
if tolerate_small_mismatch:
print("small mismatch detected", error)
else:
return False
return True
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has six inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
- len_cache, avg_cache, attn_cache, cnn_cache, the flattened streaming caches
and six outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
- new_len_cache, new_avg_cache, new_attn_cache, new_cnn_cache, the updated caches
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
batch_size = 17
seq_len = 101
torch.manual_seed(0)
x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
states = stack_states(initial_states)
left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
encoder_attention_dim = encoder_model.encoders[0].attention_dim
len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15
avg_cache = torch.cat(
states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
).transpose(
0, 1
) # [B,15,384]
cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
0, 1
) # [B,2*15,384,cnn_kernel-1]
pad_tensors = [
torch.nn.functional.pad(
tensor,
(
0,
encoder_attention_dim - tensor.shape[-1],
0,
0,
0,
left_context_len - tensor.shape[1],
0,
0,
),
)
for tensor in states[
2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
]
]
attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)
torch.onnx.export(
encoder_model_wrapper,
(x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"x",
"x_lens",
"len_cache",
"avg_cache",
"attn_cache",
"cnn_cache",
],
output_names=[
"encoder_out",
"encoder_out_lens",
"new_len_cache",
"new_avg_cache",
"new_attn_cache",
"new_cnn_cache",
],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
"len_cache": {0: "N"},
"avg_cache": {0: "N"},
"attn_cache": {0: "N"},
"cnn_cache": {0: "N"},
"new_len_cache": {0: "N"},
"new_avg_cache": {0: "N"},
"new_attn_cache": {0: "N"},
"new_cnn_cache": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
# Test onnx encoder with torch native encoder
encoder_model.eval()
(
encoder_out_torch,
encoder_out_lens_torch,
new_states_torch,
) = encoder_model.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
ort_session = onnxruntime.InferenceSession(
str(encoder_filename), providers=["CPUExecutionProvider"]
)
ort_inputs = {
"x": x.numpy(),
"x_lens": x_lens.numpy(),
"len_cache": len_cache.numpy(),
"avg_cache": avg_cache.numpy(),
"attn_cache": attn_cache.numpy(),
"cnn_cache": cnn_cache.numpy(),
}
ort_outs = ort_session.run(None, ort_inputs)
assert test_acc(
[encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
)
logging.info(f"{encoder_filename} acc test succeeded.")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_decoder_model_onnx_triton(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
decoder_model = TritonOnnxDecoder(decoder_model)
torch.onnx.export(
decoder_model,
(y,),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
def export_joiner_model_onnx_triton(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has two inputs:
- encoder_out: a tensor of shape (N, encoder_out_dim)
- decoder_out: a tensor of shape (N, decoder_out_dim)
and has one output:
- joiner_out: a tensor of shape (N, vocab_size)
Note: The argument project_input is fixed to True. Users should not
project encoder_out/decoder_out themselves; the exported joiner
does the projection for them.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
joiner_model = TritonOnnxJoiner(joiner_model)
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out"],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
@torch.no_grad()
def main():
args = get_parser().parse_args()
@ -292,7 +769,87 @@ def main():
model.to("cpu")
model.eval()
-if params.jit is True:
+if params.onnx:
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
if not params.onnx_triton:
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
else:
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx_triton(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx_triton(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
if params.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import convert_float_to_float16
except ImportError:
print("Please install onnxmltools!")
import sys
sys.exit(1)
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
export_onnx_fp16(encoder_filename, encoder_fp16_filename)
decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
export_onnx_fp16(decoder_filename, decoder_fp16_filename)
joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
export_onnx_fp16(joiner_filename, joiner_fp16_filename)
if not params.onnx_triton:
encoder_proj_filename = str(joiner_filename).replace(
".onnx", "_encoder_proj.onnx"
)
encoder_proj_fp16_filename = (
params.exp_dir / "joiner_encoder_proj_fp16.onnx"
)
export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)
decoder_proj_filename = str(joiner_filename).replace(
".onnx", "_decoder_proj.onnx"
)
decoder_proj_fp16_filename = (
params.exp_dir / "joiner_decoder_proj_fp16.onnx"
)
export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)
elif params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
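Since convert_float_to_float16 halves the mantissa, the fp16 models produced above should be compared against their fp32 counterparts with a looser tolerance than test_acc's defaults. A hedged sketch for the decoder pair (illustrative only; fp16 inference on the CPU provider may be slow or fall back through inserted casts, and the vocabulary size 500 assumes the bpe_500 model):

import numpy as np
import onnxruntime

fp32 = onnxruntime.InferenceSession("decoder.onnx", providers=["CPUExecutionProvider"])
fp16 = onnxruntime.InferenceSession("decoder_fp16.onnx", providers=["CPUExecutionProvider"])

name = fp32.get_inputs()[0].name  # "y"
context_size = fp32.get_inputs()[0].shape[1]
y = np.random.randint(0, 500, size=(4, context_size), dtype=np.int64)

out32 = fp32.run(None, {name: y})[0]
out16 = fp16.run(None, {name: y})[0].astype(np.float32)
np.testing.assert_allclose(out32, out16, rtol=1e-2, atol=1e-3)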

pruned_transducer_stateless7_streaming/onnx_model_wrapper.py

@@ -0,0 +1,231 @@
from typing import Optional, Tuple
import torch
class OnnxStreamingEncoder(torch.nn.Module):
"""This class warps the streaming Zipformer to reduce the number of
state tensors for onnx.
https://github.com/k2-fsa/icefall/pull/831
"""
def __init__(self, encoder):
"""
Args:
encoder: An instance of the Zipformer class
"""
super().__init__()
self.model = encoder
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
len_cache: torch.Tensor,
avg_cache: torch.Tensor,
attn_cache: torch.Tensor,
cnn_cache: torch.Tensor,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
len_cache:
The cached numbers of past frames.
avg_cache:
The cached average tensors.
attn_cache:
The packed attention caches: the cached key tensors of the first
attention modules, the cached value tensors of the first attention
modules, and the cached value tensors of the second attention modules.
cnn_cache:
The packed convolution caches: the cached left contexts of the first
and second convolution modules.
Returns:
Return a tuple containing 6 tensors: encoder_out, encoder_out_lens, new_len_cache, new_avg_cache, new_attn_cache, and new_cnn_cache.
"""
num_encoder_layers = []
encoder_attention_dims = []
states = []
for i, encoder in enumerate(self.model.encoders):
num_encoder_layers.append(encoder.num_layers)
encoder_attention_dims.append(encoder.attention_dim)
len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B]
offset = 0
for num_layer in num_encoder_layers:
states.append(len_cache[offset : offset + num_layer])
offset += num_layer
avg_cache = avg_cache.transpose(0, 1) # [15, B, 384]
offset = 0
for num_layer in num_encoder_layers:
states.append(avg_cache[offset : offset + num_layer])
offset += num_layer
attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192]
left_context_len = attn_cache.shape[1]
offset = 0
for i, num_layer in enumerate(num_encoder_layers):
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[offset : offset + num_layer, : left_context_len // ds]
)
offset += num_layer
for i, num_layer in enumerate(num_encoder_layers):
encoder_attention_dim = encoder_attention_dims[i]
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[
offset : offset + num_layer,
: left_context_len // ds,
:,
: encoder_attention_dim // 2,
]
)
offset += num_layer
for i, num_layer in enumerate(num_encoder_layers):
encoder_attention_dim = encoder_attention_dims[i]
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[
offset : offset + num_layer,
: left_context_len // ds,
:,
: encoder_attention_dim // 2,
]
)
offset += num_layer
cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1]
offset = 0
for num_layer in num_encoder_layers:
states.append(cnn_cache[offset : offset + num_layer])
offset += num_layer
for num_layer in num_encoder_layers:
states.append(cnn_cache[offset : offset + num_layer])
offset += num_layer
encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
new_len_cache = torch.cat(new_states[: self.model.num_encoders]).transpose(
0, 1
)  # [B,15]
new_avg_cache = torch.cat(
new_states[self.model.num_encoders : 2 * self.model.num_encoders]
).transpose(
0, 1
)  # [B,15,384]
new_cnn_cache = torch.cat(new_states[5 * self.model.num_encoders :]).transpose(
0, 1
)  # [B,2*15,384,cnn_kernel-1]
assert len(set(encoder_attention_dims)) == 1
pad_tensors = [
torch.nn.functional.pad(
tensor,
(
0,
encoder_attention_dims[0] - tensor.shape[-1],
0,
0,
0,
left_context_len - tensor.shape[1],
0,
0,
),
)
for tensor in new_states[
2 * self.model.num_encoders : 5 * self.model.num_encoders
]
]
new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
return (
encoder_out,
encoder_out_lens,
new_len_cache,
new_avg_cache,
new_attn_cache,
new_cnn_cache,
)
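For orientation, here is how the four flat tensors line up with the per-layer state list consumed by streaming_forward, under an illustrative assumption of the default configuration (5 encoders whose layer counts sum to 15, attention dim 192, encoder dim 384):

# states[0:5]    <->  len_cache  [B, 15]            one length per layer
# states[5:10]   <->  avg_cache  [B, 15, 384]       running averages
# states[10:25]  <->  attn_cache [B, 64, 45, 192]   keys + two value groups;
#                                                   values padded from 96 to 192
# states[25:35]  <->  cnn_cache  [B, 30, 384, k-1]  two convolution groups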
class TritonOnnxDecoder(torch.nn.Module):
"""This class warps the Decoder in decoder.py
to remove the scalar input "need_pad".
Triton currently doesn't support scalar input.
https://github.com/triton-inference-server/server/issues/2333
"""
def __init__(
self,
decoder: torch.nn.Module,
):
"""
Args:
decoder: An instance of Decoder
"""
super().__init__()
self.model = decoder
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
# False to not pad the input. Should be False during inference.
need_pad = False
return self.model(y, need_pad)
class TritonOnnxJoiner(torch.nn.Module):
"""This class warps the Joiner in joiner.py
to remove the scalar input "project_input".
Triton currently doesn't support scalar input.
https://github.com/triton-inference-server/server/issues/2333
"project_input" is set to True.
Triton solutions only need export joiner to a single joiner.onnx.
"""
def __init__(
self,
joiner: torch.nn.Module,
):
super().__init__()
self.model = joiner
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
# Apply input projections encoder_proj and decoder_proj.
project_input = True
return self.model(encoder_out, decoder_out, project_input)

pruned_transducer_stateless7_streaming/zipformer.py

@@ -2084,16 +2084,26 @@ class RelPositionMultiheadAttention(nn.Module):
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
-pos_weights = pos_weights.as_strided(
-(bsz, num_heads, seq_len, seq_len),
-(
-pos_weights.stride(0),
-pos_weights.stride(1),
-pos_weights.stride(2) - pos_weights.stride(3),
-pos_weights.stride(3),
-),
-storage_offset=pos_weights.stride(3) * (seq_len - 1),
-)
+if torch.jit.is_tracing():
+(batch_size, num_heads, time1, n) = pos_weights.shape
+rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+cols = torch.arange(seq_len)
+rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+indexes = rows + cols
+pos_weights = pos_weights.reshape(-1, n)
+pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
+pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len)
+else:
+pos_weights = pos_weights.as_strided(
+(bsz, num_heads, seq_len, seq_len),
+(
+pos_weights.stride(0),
+pos_weights.stride(1),
+pos_weights.stride(2) - pos_weights.stride(3),
+pos_weights.stride(3),
+),
+storage_offset=pos_weights.stride(3) * (seq_len - 1),
+)
# caution: they are really scores at this point.
attn_output_weights = torch.matmul(q, k) + pos_weights
@@ -2275,16 +2285,26 @@ class RelPositionMultiheadAttention(nn.Module):
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
-pos_weights = pos_weights.as_strided(
-(bsz, num_heads, seq_len, kv_len),
-(
-pos_weights.stride(0),
-pos_weights.stride(1),
-pos_weights.stride(2) - pos_weights.stride(3),
-pos_weights.stride(3),
-),
-storage_offset=pos_weights.stride(3) * (seq_len - 1),
-)
+if torch.jit.is_tracing():
+(batch_size, num_heads, time1, n) = pos_weights.shape
+rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+cols = torch.arange(kv_len)
+rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+indexes = rows + cols
+pos_weights = pos_weights.reshape(-1, n)
+pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
+pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len)
+else:
+pos_weights = pos_weights.as_strided(
+(bsz, num_heads, seq_len, kv_len),
+(
+pos_weights.stride(0),
+pos_weights.stride(1),
+pos_weights.stride(2) - pos_weights.stride(3),
+pos_weights.stride(3),
+),
+storage_offset=pos_weights.stride(3) * (seq_len - 1),
+)
# caution: they are really scores at this point.
attn_output_weights = torch.matmul(q, k) + pos_weights
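The gather branch exists because torch.jit.trace (which torch.onnx.export uses) cannot export the .as_strided() trick reliably, while plain index arithmetic traces cleanly. A small self-contained check that both paths perform the same relative-to-absolute position conversion (an illustration, not part of the commit; it assumes the last axis holds the usual 2*seq_len-1 relative positions):

import torch

bsz, num_heads, seq_len = 2, 4, 5
pos_weights = torch.randn(bsz, num_heads, seq_len, 2 * seq_len - 1)

# Original path: a strided view that shifts row i left by i positions.
strided = pos_weights.as_strided(
    (bsz, num_heads, seq_len, seq_len),
    (
        pos_weights.stride(0),
        pos_weights.stride(1),
        pos_weights.stride(2) - pos_weights.stride(3),
        pos_weights.stride(3),
    ),
    storage_offset=pos_weights.stride(3) * (seq_len - 1),
)

# Tracing-friendly path: build the same index pattern and gather it.
(batch_size, n_heads, time1, n) = pos_weights.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
indexes = rows.repeat(batch_size * n_heads).unsqueeze(-1) + cols
gathered = torch.gather(pos_weights.reshape(-1, n), dim=1, index=indexes)
gathered = gathered.reshape(batch_size, n_heads, time1, seq_len)

assert torch.equal(strided, gathered)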


@@ -22,5 +22,6 @@ typeguard==2.13.3
multi_quantization
onnx
onnxmltools
onnxruntime
kaldifst