Merge branch 'k2-fsa:master' into repeat-k

Commit 51812b2ddc by Yifan Yang, 2023-02-06 16:55:06 +08:00, committed by GitHub
13 changed files with 908 additions and 46 deletions

View File

@@ -33,6 +33,16 @@ ln -s pretrained.pt epoch-99.pt
ls -lh *.pt
popd
log "Test exporting to ONNX format"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--fp16 \
--onnx 1
log "Export to torchscript model"
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir $repo/exp \

View File

@@ -39,7 +39,7 @@ concurrency:
jobs:
run_librispeech_2022_12_29_zipformer_streaming:
if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event.label.name == 'streaming-zipformer' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:

View File

View File

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
# Zengwei Yao)
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -18,7 +19,7 @@
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple
import torch
import torch.nn as nn
@@ -44,6 +45,7 @@ class FrameReducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
ctc_output: torch.Tensor,
y_lens: Optional[torch.Tensor] = None,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -55,6 +57,9 @@ class FrameReducer(nn.Module):
`x` before padding.
ctc_output:
The CTC output with shape [N, T, vocab_size].
y_lens:
A tensor of shape (batch_size,) containing the number of labels in
each utterance of `y`; the reduced output keeps at least this many
frames per utterance.
blank_id:
The blank id of ctc_output.
Returns:
@@ -64,15 +69,45 @@ class FrameReducer(nn.Module):
A tensor of shape (batch_size,) containing the number of frames in
`out` before padding.
"""
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
if y_lens is not None:
# Limit the number of frames that can be removed so that at least y_lens frames remain
limit_lens = T - y_lens
max_limit_len = limit_lens.max().int()
fake_limit_indexes = torch.topk(
ctc_output[:, :, blank_id], max_limit_len
).indices
T = (
torch.arange(max_limit_len)
.expand_as(
fake_limit_indexes,
)
.to(device=x.device)
)
T = torch.remainder(T, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, T)
limit_mask = torch.full_like(
non_blank_mask,
False,
device=x.device,
).scatter_(1, limit_indexes, True)
non_blank_mask = non_blank_mask | ~limit_mask
out_lens = non_blank_mask.sum(dim=1)
max_len = out_lens.max()
pad_lens_list = torch.full_like(out_lens, max_len.item()) - out_lens
pad_lens_list = (
torch.full_like(
out_lens,
max_len.item(),
device=x.device,
)
- out_lens
)
max_pad_len = pad_lens_list.max()
out = F.pad(x, (0, 0, 0, max_pad_len))
@@ -82,26 +117,30 @@ class FrameReducer(nn.Module):
out = out[total_valid_mask].reshape(N, -1, C)
return out.to(device=x.device), out_lens.to(device=x.device)
return out, out_lens
if __name__ == "__main__":
import time
from torch.nn.utils.rnn import pad_sequence
test_times = 10000
device = "cuda:0"
frame_reducer = FrameReducer()
# non zero case
x = torch.ones(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.log(torch.randn(15, 498, 500, dtype=torch.float32))
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
ctc_output = torch.log(
torch.randn(15, 498, 500, dtype=torch.float32, device=device),
)
avg_time = 0
for i in range(test_times):
torch.cuda.synchronize(device=x.device)
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
torch.cuda.synchronize(device=x.device)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
@@ -109,14 +148,17 @@ if __name__ == "__main__":
print(avg_time / test_times)
# all zero case
x = torch.zeros(15, 498, 384, dtype=torch.float32)
x_lens = torch.tensor([498] * 15, dtype=torch.int64)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32)
x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)
avg_time = 0
for i in range(test_times):
torch.cuda.synchronize(device=x.device)
delta_time = time.time()
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output)
x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
torch.cuda.synchronize(device=x.device)
delta_time = time.time() - delta_time
avg_time += delta_time
print(x_fr.shape)
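For reference, here is a minimal, self-contained sketch (toy values; not part of the diff) of the y_lens cap added above: frames whose blank probability is at least 0.9 are dropped, but the most blank-like frames are put back until at least y_lens frames per utterance remain. It mirrors the masking logic in forward().

import math

import torch

# One utterance, T=6 frames, y_lens=4: at most 2 frames may be removed.
log_blank = torch.tensor([[-0.01, -2.0, -0.02, -3.0, -0.05, -0.03]])
non_blank_mask = log_blank < math.log(0.9)  # frames kept because they are not blank-dominated
y_lens = torch.tensor([4])
T = log_blank.size(1)

limit_lens = T - y_lens  # per-utterance budget of removable frames
max_limit_len = int(limit_lens.max())
fake_limit_indexes = torch.topk(log_blank, max_limit_len).indices
idx = torch.arange(max_limit_len).expand_as(fake_limit_indexes)
idx = torch.remainder(idx, limit_lens.unsqueeze(1))
limit_indexes = torch.gather(fake_limit_indexes, 1, idx)
limit_mask = torch.full_like(non_blank_mask, False).scatter_(1, limit_indexes, True)

keep_mask = non_blank_mask | ~limit_mask
print(keep_mask.sum(dim=1))  # tensor([4]): at least y_lens frames survive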

View File

View File

@@ -131,6 +131,10 @@ class Transducer(nn.Module):
# compute ctc log-probs
ctc_output = self.ctc_output(encoder_out)
# y_lens
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
# blank skip
blank_id = self.decoder.blank_id
@@ -146,16 +150,14 @@
encoder_out,
x_lens,
ctc_output,
y_lens,
blank_id,
)
else:
encoder_out_fr = encoder_out
x_lens_fr = x_lens
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
# sos_y
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
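For reference (not part of the diff): y is a ragged label tensor, so the new y_lens is just the difference of consecutive row_splits entries, i.e. the per-utterance label count that FrameReducer now uses as a lower bound on the number of kept frames. A tiny sketch with a plain torch tensor standing in for y.shape.row_splits(1):

import torch

# Toy row_splits for 3 utterances with 3, 4 and 2 labels.
row_splits = torch.tensor([0, 3, 7, 9], dtype=torch.int64)
y_lens = row_splits[1:] - row_splits[:-1]
print(y_lens)  # tensor([3, 4, 2])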

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
@@ -35,7 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_ctc_bs/exp \
--full-libri 1 \
--max-duration 550
--max-duration 750
"""

View File

@@ -72,25 +72,81 @@ Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29
# You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
(3) Export to ONNX format with pretrained.pt
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
--fp16 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Check
https://github.com/k2-fsa/sherpa-onnx
for how to use the exported models outside of icefall.
(4) Export to ONNX format for triton server
cd ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp
ln -s pretrained.pt epoch-999.pt
./pruned_transducer_stateless7_streaming/export.py \
--exp-dir ./icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--use-averaged-model False \
--epoch 999 \
--avg 1 \
--fp16 \
--onnx-triton 1 \
--onnx 1
It will generate the following files in the given `exp_dir`.
Check `onnx_check.py` for how to use them.
- encoder.onnx
- decoder.onnx
- joiner.onnx
Check
https://github.com/k2-fsa/sherpa/tree/master/triton
for how to use the exported models outside of icefall.
"""
import argparse
import logging
from pathlib import Path
import onnxruntime
import sentencepiece as spm
import torch
import torch.nn as nn
from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states
from icefall.checkpoint import (
average_checkpoints,
@@ -172,6 +228,42 @@ def get_parser():
""",
)
parser.add_argument(
"--onnx",
type=str2bool,
default=False,
help="""If True, --jit is ignored and the model is exported
to ONNX format. It will generate the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
- joiner_encoder_proj.onnx
- joiner_decoder_proj.onnx
Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them.
""",
)
parser.add_argument(
"--onnx-triton",
type=str2bool,
default=False,
help="""If True, --onnx exports the model into the following files:
- encoder.onnx
- decoder.onnx
- joiner.onnx
These files are intended for use with https://github.com/k2-fsa/sherpa/tree/master/triton.
""",
)
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to also export fp16 ONNX models (requires onnxmltools). Default: False.",
)
parser.add_argument(
"--context-size",
type=int,
@@ -184,6 +276,391 @@ def get_parser():
return parser
def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
for a, b in zip(xlist, blist):
try:
torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
except AssertionError as error:
if tolerate_small_mismatch:
print("small mismatch detected", error)
else:
return False
return True
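A short usage sketch for the test_acc helper above (assumed to run in the same module): with the default tolerate_small_mismatch=True, mismatches are only printed and the function still returns True; pass tolerate_small_mismatch=False to make it fail on a real difference.

import torch

a = [torch.randn(4, 5)]
print(test_acc(a, [t + 1e-6 for t in a]))  # True: within the default rtol/atol
print(test_acc(a, [t + 1.0 for t in a], tolerate_small_mismatch=False))  # False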
def export_encoder_model_onnx(
encoder_model: nn.Module,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported streaming encoder has six inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
- len_cache, avg_cache, attn_cache, cnn_cache: the flattened streaming state tensors
and six outputs:
- encoder_out, a tensor of shape (N, T, C)
- encoder_out_lens, a tensor of shape (N,)
- new_len_cache, new_avg_cache, new_attn_cache, new_cnn_cache: the updated state tensors
Note: The warmup argument is fixed to 1.
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
batch_size = 17
seq_len = 101
torch.manual_seed(0)
x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32)
x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64)
# encoder_model = torch.jit.script(encoder_model)
# It throws the following error for the above statement
#
# RuntimeError: Exporting the operator __is_ to ONNX opset version
# 11 is not supported. Please feel free to request support or
# submit a pull request on PyTorch GitHub.
#
# I cannot find which statement causes the above error.
# torch.onnx.export() will use torch.jit.trace() internally, which
# works well for the current reworked model
initial_states = [encoder_model.get_init_state() for _ in range(batch_size)]
states = stack_states(initial_states)
left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks
encoder_attention_dim = encoder_model.encoders[0].attention_dim
len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15
avg_cache = torch.cat(
states[encoder_model.num_encoders : 2 * encoder_model.num_encoders]
).transpose(
0, 1
) # [B,15,384]
cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose(
0, 1
) # [B,2*15,384,cnn_kernel-1]
pad_tensors = [
torch.nn.functional.pad(
tensor,
(
0,
encoder_attention_dim - tensor.shape[-1],
0,
0,
0,
left_context_len - tensor.shape[1],
0,
0,
),
)
for tensor in states[
2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders
]
]
attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
encoder_model_wrapper = OnnxStreamingEncoder(encoder_model)
torch.onnx.export(
encoder_model_wrapper,
(x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"x",
"x_lens",
"len_cache",
"avg_cache",
"attn_cache",
"cnn_cache",
],
output_names=[
"encoder_out",
"encoder_out_lens",
"new_len_cache",
"new_avg_cache",
"new_attn_cache",
"new_cnn_cache",
],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
"len_cache": {0: "N"},
"avg_cache": {0: "N"},
"attn_cache": {0: "N"},
"cnn_cache": {0: "N"},
"new_len_cache": {0: "N"},
"new_avg_cache": {0: "N"},
"new_attn_cache": {0: "N"},
"new_cnn_cache": {0: "N"},
},
)
logging.info(f"Saved to {encoder_filename}")
# Test onnx encoder with torch native encoder
encoder_model.eval()
(
encoder_out_torch,
encoder_out_lens_torch,
new_states_torch,
) = encoder_model.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
ort_session = onnxruntime.InferenceSession(
str(encoder_filename), providers=["CPUExecutionProvider"]
)
ort_inputs = {
"x": x.numpy(),
"x_lens": x_lens.numpy(),
"len_cache": len_cache.numpy(),
"avg_cache": avg_cache.numpy(),
"attn_cache": attn_cache.numpy(),
"cnn_cache": cnn_cache.numpy(),
}
ort_outs = ort_session.run(None, ort_inputs)
assert test_acc(
[encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2]
)
logging.info(f"{encoder_filename} acc test succeeded.")
def export_decoder_model_onnx(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
need_pad = False # Always False, so we can use torch.jit.trace() here
# Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script()
# in this case
torch.onnx.export(
decoder_model,
(y, need_pad),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y", "need_pad"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_decoder_model_onnx_triton(
decoder_model: nn.Module,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, 1, C)
Note: The argument need_pad is fixed to False.
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
decoder_model = TritonOnnxDecoder(decoder_model)
torch.onnx.export(
decoder_model,
(y,),
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_filename}")
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
- projected_decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
The exported encoder_proj model has one input:
- encoder_out: a tensor of shape (N, encoder_out_dim)
and produces one output:
- projected_encoder_out: a tensor of shape (N, joiner_dim)
The exported decoder_proj model has one input:
- decoder_out: a tensor of shape (N, decoder_out_dim)
and produces one output:
- projected_decoder_out: a tensor of shape (N, joiner_dim)
"""
encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx")
decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx")
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
joiner_dim = joiner_model.decoder_proj.weight.shape[0]
projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32)
project_input = False
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out, project_input),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
"project_input",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.encoder_proj,
encoder_out,
encoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out"],
output_names=["projected_encoder_out"],
dynamic_axes={
"encoder_out": {0: "N"},
"projected_encoder_out": {0: "N"},
},
)
logging.info(f"Saved to {encoder_proj_filename}")
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model.decoder_proj,
decoder_out,
decoder_proj_filename,
verbose=False,
opset_version=opset_version,
input_names=["decoder_out"],
output_names=["projected_decoder_out"],
dynamic_axes={
"decoder_out": {0: "N"},
"projected_decoder_out": {0: "N"},
},
)
logging.info(f"Saved to {decoder_proj_filename}")
def export_joiner_model_onnx_triton(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported model has two inputs:
- encoder_out: a tensor of shape (N, encoder_out_dim)
- decoder_out: a tensor of shape (N, decoder_out_dim)
and has one output:
- joiner_out: a tensor of shape (N, vocab_size)
Note: The argument project_input is fixed to True. Users should not
project encoder_out/decoder_out themselves; the exported joiner
does that for them.
"""
encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
joiner_model = TritonOnnxJoiner(joiner_model)
# Note: It uses torch.jit.trace() internally
torch.onnx.export(
joiner_model,
(encoder_out, decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=["encoder_out", "decoder_out"],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
logging.info(f"Saved to {joiner_filename}")
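A rough onnxruntime sanity check (file path assumed; not part of the diff) for the Triton-style joiner exported above: unlike the default export, it accepts unprojected encoder_out/decoder_out directly and returns vocabulary logits, so no separate projection models are needed.

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("exp/joiner.onnx", providers=["CPUExecutionProvider"])
shapes = {i.name: i.shape for i in sess.get_inputs()}
enc_dim = shapes["encoder_out"][-1]  # static encoder_out_dim
dec_dim = shapes["decoder_out"][-1]  # static decoder_out_dim
logit = sess.run(
    ["logit"],
    {
        "encoder_out": np.random.rand(2, enc_dim).astype(np.float32),
        "decoder_out": np.random.rand(2, dec_dim).astype(np.float32),
    },
)[0]
print(logit.shape)  # (2, vocab_size)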
@torch.no_grad()
def main():
args = get_parser().parse_args()
@@ -292,7 +769,87 @@ def main():
model.to("cpu")
model.eval()
if params.jit is True:
if params.onnx:
convert_scaled_to_non_scaled(model, inplace=True)
opset_version = 13
logging.info("Exporting to onnx format")
encoder_filename = params.exp_dir / "encoder.onnx"
export_encoder_model_onnx(
model.encoder,
encoder_filename,
opset_version=opset_version,
)
if not params.onnx_triton:
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
else:
decoder_filename = params.exp_dir / "decoder.onnx"
export_decoder_model_onnx_triton(
model.decoder,
decoder_filename,
opset_version=opset_version,
)
joiner_filename = params.exp_dir / "joiner.onnx"
export_joiner_model_onnx_triton(
model.joiner,
joiner_filename,
opset_version=opset_version,
)
if params.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import convert_float_to_float16
except ImportError:
print("Please install onnxmltools!")
import sys
sys.exit(1)
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model)
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx"
export_onnx_fp16(encoder_filename, encoder_fp16_filename)
decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx"
export_onnx_fp16(decoder_filename, decoder_fp16_filename)
joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx"
export_onnx_fp16(joiner_filename, joiner_fp16_filename)
if not params.onnx_triton:
encoder_proj_filename = str(joiner_filename).replace(
".onnx", "_encoder_proj.onnx"
)
encoder_proj_fp16_filename = (
params.exp_dir / "joiner_encoder_proj_fp16.onnx"
)
export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename)
decoder_proj_filename = str(joiner_filename).replace(
".onnx", "_decoder_proj.onnx"
)
decoder_proj_fp16_filename = (
params.exp_dir / "joiner_decoder_proj_fp16.onnx"
)
export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename)
elif params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.

View File

@@ -0,0 +1,231 @@
from typing import Optional, Tuple
import torch
class OnnxStreamingEncoder(torch.nn.Module):
"""This class wraps the streaming Zipformer to reduce the number of
state tensors for ONNX export.
https://github.com/k2-fsa/icefall/pull/831
"""
def __init__(self, encoder):
"""
Args:
encoder: An instance of the Zipformer class
"""
super().__init__()
self.model = encoder
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
len_cache: torch.Tensor,
avg_cache: torch.Tensor,
attn_cache: torch.Tensor,
cnn_cache: torch.Tensor,
) -> Tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
len_cache:
The cached numbers of past frames.
avg_cache:
The cached average tensors.
attn_cache:
The cached attention tensors, packed into one tensor: the key
tensors of the first attention modules and the value tensors of
the first and second attention modules.
cnn_cache:
The cached left contexts of the first and second convolution
modules, packed into one tensor.
Returns:
Return a tuple of 6 tensors: encoder_out, encoder_out_lens,
new_len_cache, new_avg_cache, new_attn_cache and new_cnn_cache.
"""
num_encoder_layers = []
encoder_attention_dims = []
states = []
for i, encoder in enumerate(self.model.encoders):
num_encoder_layers.append(encoder.num_layers)
encoder_attention_dims.append(encoder.attention_dim)
len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B]
offset = 0
for num_layer in num_encoder_layers:
states.append(len_cache[offset : offset + num_layer])
offset += num_layer
avg_cache = avg_cache.transpose(0, 1) # [15, B, 384]
offset = 0
for num_layer in num_encoder_layers:
states.append(avg_cache[offset : offset + num_layer])
offset += num_layer
attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192]
left_context_len = attn_cache.shape[1]
offset = 0
for i, num_layer in enumerate(num_encoder_layers):
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[offset : offset + num_layer, : left_context_len // ds]
)
offset += num_layer
for i, num_layer in enumerate(num_encoder_layers):
encoder_attention_dim = encoder_attention_dims[i]
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[
offset : offset + num_layer,
: left_context_len // ds,
:,
: encoder_attention_dim // 2,
]
)
offset += num_layer
for i, num_layer in enumerate(num_encoder_layers):
ds = self.model.zipformer_downsampling_factors[i]
states.append(
attn_cache[
offset : offset + num_layer,
: left_context_len // ds,
:,
: encoder_attention_dim // 2,
]
)
offset += num_layer
cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1]
offset = 0
for num_layer in num_encoder_layers:
states.append(cnn_cache[offset : offset + num_layer])
offset += num_layer
for num_layer in num_encoder_layers:
states.append(cnn_cache[offset : offset + num_layer])
offset += num_layer
encoder_out, encoder_out_lens, new_states = self.model.streaming_forward(
x=x,
x_lens=x_lens,
states=states,
)
new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose(
0, 1
) # [B,15]
new_avg_cache = torch.cat(
states[self.model.num_encoders : 2 * self.model.num_encoders]
).transpose(
0, 1
) # [B,15,384]
new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose(
0, 1
) # [B,2*15,384,cnn_kernel-1]
assert len(set(encoder_attention_dims)) == 1
pad_tensors = [
torch.nn.functional.pad(
tensor,
(
0,
encoder_attention_dims[0] - tensor.shape[-1],
0,
0,
0,
left_context_len - tensor.shape[1],
0,
0,
),
)
for tensor in states[
2 * self.model.num_encoders : 5 * self.model.num_encoders
]
]
new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192]
return (
encoder_out,
encoder_out_lens,
new_len_cache,
new_avg_cache,
new_attn_cache,
new_cnn_cache,
)
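A toy illustration (not the real model; shapes made up) of the flattening scheme this wrapper relies on: each group of per-encoder state tensors of shape (num_layers, B, ...) is concatenated along the layer axis and transposed so the batch axis comes first, giving ONNX a single cache tensor per group; forward() above reverses the transform and slices the layers back out.

import torch

num_encoder_layers = [2, 4, 3]  # assumed layer counts; 9 layers in total
B = 2
per_encoder = [torch.randn(n, B, 8) for n in num_encoder_layers]

flat = torch.cat(per_encoder).transpose(0, 1)  # (B, 9, 8): one ONNX cache input

unflat = flat.transpose(0, 1)  # back to (9, B, 8)
restored, offset = [], 0
for n in num_encoder_layers:
    restored.append(unflat[offset : offset + n])
    offset += n
assert all(torch.equal(a, b) for a, b in zip(per_encoder, restored))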
class TritonOnnxDecoder(torch.nn.Module):
"""This class wraps the Decoder in decoder.py
to remove the scalar input "need_pad".
Triton currently doesn't support scalar input.
https://github.com/triton-inference-server/server/issues/2333
"""
def __init__(
self,
decoder: torch.nn.Module,
):
"""
Args:
decoder: An instance of Decoder
"""
super().__init__()
self.model = decoder
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
# False to not pad the input. Should be False during inference.
need_pad = False
return self.model(y, need_pad)
class TritonOnnxJoiner(torch.nn.Module):
"""This class wraps the Joiner in joiner.py
to remove the scalar input "project_input".
Triton currently doesn't support scalar input.
https://github.com/triton-inference-server/server/issues/2333
"project_input" is fixed to True, so Triton deployments only need the
joiner exported as a single joiner.onnx.
"""
def __init__(
self,
joiner: torch.nn.Module,
):
super().__init__()
self.model = joiner
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
# Apply input projections encoder_proj and decoder_proj.
project_input = True
return self.model(encoder_out, decoder_out, project_input)
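A minimal sketch (assuming it runs in the same module as the wrappers above; ToyDecoder is a made-up stand-in for the real Decoder) of what the wrapping buys: the scalar need_pad argument disappears from the traced interface, leaving only tensor inputs for torch.onnx.export.

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    """A made-up stand-in for the real Decoder."""

    def __init__(self, vocab_size: int = 500, decoder_dim: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim)

    def forward(self, y: torch.Tensor, need_pad: bool = False) -> torch.Tensor:
        return self.embedding(y)

wrapped = TritonOnnxDecoder(ToyDecoder())  # need_pad is now fixed inside forward()
torch.onnx.export(
    wrapped,
    (torch.zeros(3, 2, dtype=torch.int64),),
    "toy_decoder.onnx",
    input_names=["y"],
    output_names=["decoder_out"],
    dynamic_axes={"y": {0: "N"}, "decoder_out": {0: "N"}},
)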

View File

@@ -2084,6 +2084,16 @@ class RelPositionMultiheadAttention(nn.Module):
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(batch_size, num_heads, time1, n) = pos_weights.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_weights = pos_weights.reshape(-1, n)
pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len)
else:
pos_weights = pos_weights.as_strided(
(bsz, num_heads, seq_len, seq_len),
(
@@ -2275,6 +2285,16 @@ class RelPositionMultiheadAttention(nn.Module):
# the following .as_strided() expression converts the last axis of pos_weights from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(batch_size, num_heads, time1, n) = pos_weights.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(kv_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_weights = pos_weights.reshape(-1, n)
pos_weights = torch.gather(pos_weights, dim=1, index=indexes)
pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len)
else:
pos_weights = pos_weights.as_strided(
(bsz, num_heads, seq_len, kv_len),
(
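The torch.jit.is_tracing() branches above replace .as_strided() with a reshape/gather that computes the same relative-to-absolute conversion but can be traced for ONNX export. A self-contained toy check (not part of the diff; it assumes the relative axis has length time1 + seq_len - 1, which the index arithmetic implies):

import torch

batch_size, num_heads, time1, seq_len = 1, 1, 4, 4
n = time1 + seq_len - 1  # length of the relative-position axis
# Entry r along the last axis encodes relative offset r - (time1 - 1).
rel = torch.arange(n, dtype=torch.float32) - (time1 - 1)
pos_weights = rel.expand(batch_size, num_heads, time1, n).contiguous()

rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
out = torch.gather(pos_weights.reshape(-1, n), dim=1, index=indexes)
out = out.reshape(batch_size, num_heads, time1, seq_len)
print(out[0, 0])  # entry (t, j) now holds the offset j - t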

View File

@@ -22,5 +22,6 @@ typeguard==2.13.3
multi_quantization
onnx
onnxmltools
onnxruntime
kaldifst