add updated zipformer onnx export (#1108)

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
danfu authored on 2023-06-12 14:02:23 +08:00; committed by GitHub
parent b4c38d7547
commit 0cb71ad3bc
9 changed files with 2037 additions and 55 deletions


@ -0,0 +1,775 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
The --chunk-size in training is "16,32,64,-1", so we select one of them
(excluding -1) during streaming export. The same applies to `--left-context-frames`,
whose value in training is "64,128,256,-1".
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
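        # pad_length: the Conv2dSubsampling in encoder_embed consumes 7 input
        # frames ((T - 7) // 2 frames come out), and its ConvNeXt module needs
        # 3 frames of right padding after subsampling, i.e. 2 * 3 frames at the
        # input frame rate.  With --chunk-size 16 this gives T = 32 + 13 = 45.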
self.pad_length = 7 + 2 * 3
def forward(
self,
x: torch.Tensor,
states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
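        """Streaming forward for one chunk.

        Args:
          x:
            A 3-D tensor of shape (N, T, C) with T = 2 * chunk_size + pad_length.
          states:
            A list of cached tensors; see ``get_init_states()`` for the layout.
        Returns:
          Return a tuple containing:
            - encoder_out, a 3-D tensor of shape (N, T', joiner_dim)
            - new_states, the updated list of cached tensors
        """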
N = x.size(0)
T = self.chunk_size * 2 + self.pad_length
x_lens = torch.tensor([T] * N, device=x.device)
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
x=x,
x_lens=x_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2)
encoder_states = states[:-2]
logging.info(f"len_encoder_states={len(encoder_states)}")
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = self.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, new_states
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = self.encoder.get_init_states(batch_size, device)
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
states.append(processed_lens)
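        # For the default streaming config (--num-encoder-layers "2,2,3,4,3,2")
        # this is 16 layers * 6 cached tensors + embed_states + processed_lens
        # = 98 tensors in total.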
return states
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
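    """Export the streaming encoder model to ONNX format.

    The exported model has the following inputs:
      - x, a tensor of shape (N, T, C); dtype is torch.float32
      - the cached tensors returned by OnnxEncoder.get_init_states()
    and the following outputs:
      - encoder_out, a tensor of shape (N, T', joiner_dim)
      - the updated cached tensors (names prefixed with "new_")

    Args:
      encoder_model:
        The input streaming encoder model.
      encoder_filename:
        The filename to save the exported ONNX model.
      opset_version:
        The opset version to use.
    """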
encoder_model.encoder.__class__.forward = (
encoder_model.encoder.__class__.streaming_forward
)
decode_chunk_len = encoder_model.chunk_size * 2
    # The encoder_embed subsamples the features: (T - 7) // 2 frames come out
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
T = decode_chunk_len + encoder_model.pad_length
x = torch.rand(1, T, 80, dtype=torch.float32)
init_state = encoder_model.get_init_states()
num_encoders = len(encoder_model.encoder.encoder_dim)
logging.info(f"num_encoders: {num_encoders}")
logging.info(f"len(init_state): {len(init_state)}")
inputs = {}
input_names = ["x"]
outputs = {}
output_names = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
logging.info(f"{name}.shape: {tensors[0].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
logging.info(f"{name}.shape: {tensors[1].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
logging.info(f"{name}.shape: {tensors[2].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
logging.info(f"{name}.shape: {tensors[3].shape}")
inputs[name] = {1: "N"}
outputs[f"new_{name}"] = {1: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
logging.info(f"{name}.shape: {tensors[4].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
logging.info(f"{name}.shape: {tensors[5].shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
ds = encoder_model.encoder.downsampling_factor
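    # Each encoder stack caches its left context at its own (downsampled) frame
    # rate, e.g. --left-context-frames 64 with --downsampling-factor
    # "1,2,4,8,4,2" gives "64,32,16,8,16,32".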
left_context_len = encoder_model.left_context_len
left_context_len = [left_context_len // k for k in ds]
left_context_len = ",".join(map(str, left_context_len))
query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "streaming zipformer2",
"decode_chunk_len": str(decode_chunk_len), # 32
"T": str(T), # 32+7+2*3=45
"num_encoder_layers": num_encoder_layers,
"encoder_dims": encoder_dims,
"cnn_module_kernels": cnn_module_kernels,
"left_context_len": left_context_len,
"query_head_dims": query_head_dims,
"value_head_dims": value_head_dims,
"num_heads": num_heads,
}
logging.info(f"meta_data: {meta_data}")
for i in range(len(init_state[:-2]) // 6):
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
embed_states = init_state[-2]
name = "embed_states"
logging.info(f"{name}.shape: {embed_states.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
# (batch_size,)
processed_lens = init_state[-1]
name = "processed_lens"
logging.info(f"{name}.shape: {processed_lens.shape}")
inputs[name] = {0: "N"}
outputs[f"new_{name}"] = {0: "N"}
input_names.append(name)
output_names.append(f"new_{name}")
logging.info(inputs)
logging.info(outputs)
logging.info(input_names)
logging.info(output_names)
torch.onnx.export(
encoder_model,
(x, init_state),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=input_names,
output_names=output_names,
dynamic_axes={
"x": {0: "N"},
"encoder_out": {0: "N"},
**inputs,
**outputs,
},
)
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

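A quick way to inspect what the streaming export produced is to read back the
metadata written by add_meta_data() above. The following is a minimal sketch,
assuming onnxruntime is installed and the exported encoder file from step 2 is
in the current directory; it is an illustration, not code from this commit.

import onnxruntime as ort

sess = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
meta = sess.get_modelmeta().custom_metadata_map
print(meta["decode_chunk_len"], meta["T"])  # e.g. 32 45
# The first inputs are typically ['x', 'cached_key_0', 'cached_nonlin_attn_0']
print([i.name for i in sess.get_inputs()][:3])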

@ -0,0 +1,624 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script exports a transducer model from PyTorch to ONNX.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
\
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""
import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
def add_meta_data(filename: str, meta_data: Dict[str, str]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = value
onnx.save(model, filename)
class OnnxEncoder(nn.Module):
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
def __init__(
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
):
"""
Args:
encoder:
A Zipformer encoder.
encoder_proj:
The projection layer for encoder from the joiner.
"""
super().__init__()
self.encoder = encoder
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Please see the help information of Zipformer.forward
Args:
x:
A 3-D tensor of shape (N, T, C)
x_lens:
A 1-D tensor of shape (N,). Its dtype is torch.int64
Returns:
Return a tuple containing:
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
- encoder_out_lens, A 1-D tensor of shape (N,)
"""
x, x_lens = self.encoder_embed(x, x_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)
encoder_out = self.encoder_proj(encoder_out)
# Now encoder_out is of shape (N, T, joiner_dim)
return encoder_out, encoder_out_lens
class OnnxDecoder(nn.Module):
"""A wrapper for Decoder and the decoder_proj from the joiner"""
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
super().__init__()
self.decoder = decoder
self.decoder_proj = decoder_proj
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, context_size).
Returns
Return a 2-D tensor of shape (N, joiner_dim)
"""
need_pad = False
decoder_output = self.decoder(y, need_pad=need_pad)
decoder_output = decoder_output.squeeze(1)
output = self.decoder_proj(decoder_output)
return output
class OnnxJoiner(nn.Module):
"""A wrapper for the joiner"""
def __init__(self, output_linear: nn.Linear):
super().__init__()
self.output_linear = output_linear
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
logit = encoder_out + decoder_out
logit = self.output_linear(torch.tanh(logit))
return logit
def export_encoder_model_onnx(
encoder_model: OnnxEncoder,
encoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the given encoder model to ONNX format.
The exported model has two inputs:
- x, a tensor of shape (N, T, C); dtype is torch.float32
- x_lens, a tensor of shape (N,); dtype is torch.int64
and it has two outputs:
- encoder_out, a tensor of shape (N, T', joiner_dim)
- encoder_out_lens, a tensor of shape (N,)
Args:
encoder_model:
The input encoder model
encoder_filename:
The filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
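    # Note: torch.jit.trace() is applied first; convert_scaled_to_non_scaled(
    # ..., is_onnx=True) has already replaced CompactRelPositionalEncoding with
    # a scripted submodule so that the positional encoding is recomputed for
    # inputs longer than this 100-frame example (see scaling_converter.py).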
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export(
encoder_model,
(x, x_lens),
encoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["x", "x_lens"],
output_names=["encoder_out", "encoder_out_lens"],
dynamic_axes={
"x": {0: "N", 1: "T"},
"x_lens": {0: "N"},
"encoder_out": {0: "N", 1: "T"},
"encoder_out_lens": {0: "N"},
},
)
meta_data = {
"model_type": "zipformer2",
"version": "1",
"model_author": "k2-fsa",
"comment": "non-streaming zipformer2",
}
logging.info(f"meta_data: {meta_data}")
add_meta_data(filename=encoder_filename, meta_data=meta_data)
def export_decoder_model_onnx(
decoder_model: OnnxDecoder,
decoder_filename: str,
opset_version: int = 11,
) -> None:
"""Export the decoder model to ONNX format.
The exported model has one input:
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
and has one output:
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
Args:
decoder_model:
The decoder model to be exported.
decoder_filename:
Filename to save the exported ONNX model.
opset_version:
The opset version to use.
"""
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
torch.onnx.export(
decoder_model,
y,
decoder_filename,
verbose=False,
opset_version=opset_version,
input_names=["y"],
output_names=["decoder_out"],
dynamic_axes={
"y": {0: "N"},
"decoder_out": {0: "N"},
},
)
meta_data = {
"context_size": str(context_size),
"vocab_size": str(vocab_size),
}
add_meta_data(filename=decoder_filename, meta_data=meta_data)
def export_joiner_model_onnx(
joiner_model: nn.Module,
joiner_filename: str,
opset_version: int = 11,
) -> None:
"""Export the joiner model to ONNX format.
The exported joiner model has two inputs:
- encoder_out: a tensor of shape (N, joiner_dim)
- decoder_out: a tensor of shape (N, joiner_dim)
and produces one output:
- logit: a tensor of shape (N, vocab_size)
"""
joiner_dim = joiner_model.output_linear.weight.shape[1]
logging.info(f"joiner dim: {joiner_dim}")
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
torch.onnx.export(
joiner_model,
(projected_encoder_out, projected_decoder_out),
joiner_filename,
verbose=False,
opset_version=opset_version,
input_names=[
"encoder_out",
"decoder_out",
],
output_names=["logit"],
dynamic_axes={
"encoder_out": {0: "N"},
"decoder_out": {0: "N"},
"logit": {0: "N"},
},
)
meta_data = {
"joiner_dim": str(joiner_dim),
}
add_meta_data(filename=joiner_filename, meta_data=meta_data)
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,
encoder_embed=model.encoder_embed,
encoder_proj=model.joiner.encoder_proj,
)
decoder = OnnxDecoder(
decoder=model.decoder,
decoder_proj=model.joiner.decoder_proj,
)
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
logging.info(f"encoder parameters: {encoder_num_param}")
logging.info(f"decoder parameters: {decoder_num_param}")
logging.info(f"joiner parameters: {joiner_num_param}")
logging.info(f"total parameters: {total_num_param}")
if params.iter > 0:
suffix = f"iter-{params.iter}"
else:
suffix = f"epoch-{params.epoch}"
suffix += f"-avg-{params.avg}"
opset_version = 13
logging.info("Exporting encoder")
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
export_encoder_model_onnx(
encoder,
encoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported encoder to {encoder_filename}")
logging.info("Exporting decoder")
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
export_decoder_model_onnx(
decoder,
decoder_filename,
opset_version=opset_version,
)
logging.info(f"Exported decoder to {decoder_filename}")
logging.info("Exporting joiner")
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
export_joiner_model_onnx(
joiner,
joiner_filename,
opset_version=opset_version,
)
logging.info(f"Exported joiner to {joiner_filename}")
# Generate int8 quantization models
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
logging.info("Generate int8 quantization models")
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=encoder_filename,
model_output=encoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
quantize_dynamic(
model_input=decoder_filename,
model_output=decoder_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
quantize_dynamic(
model_input=joiner_filename,
model_output=joiner_filename_int8,
op_types_to_quantize=["MatMul"],
weight_type=QuantType.QInt8,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -49,7 +49,7 @@ class Transducer(nn.Module):
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape


@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)
"""
This script loads ONNX models exported by ./export-onnx-streaming.py
and uses them to decode waves.
We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.
1. Download the pre-trained model
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
ln -s pretrained.pt epoch-99.pt
popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal True \
--chunk-size 16 \
--left-context-frames 64
It will generate the following 3 files inside $repo/exp:
- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
3. Run this file with the exported ONNX models
./zipformer/onnx_pretrained-streaming.py \
--encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
--decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
--joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav
Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""
import argparse
import logging
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--encoder-model-filename",
type=str,
required=True,
help="Path to the encoder onnx model. ",
)
parser.add_argument(
"--decoder-model-filename",
type=str,
required=True,
help="Path to the decoder onnx model. ",
)
parser.add_argument(
"--joiner-model-filename",
type=str,
required=True,
help="Path to the joiner onnx model. ",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.""",
)
parser.add_argument(
"sound_file",
type=str,
help="The input sound file to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
class OnnxModel:
def __init__(
self,
encoder_model_filename: str,
decoder_model_filename: str,
joiner_model_filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.init_encoder(encoder_model_filename)
self.init_decoder(decoder_model_filename)
self.init_joiner(joiner_model_filename)
def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
)
self.init_encoder_states()
def init_encoder_states(self, batch_size: int = 1):
encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
logging.info(f"encoder_meta={encoder_meta}")
model_type = encoder_meta["model_type"]
assert model_type == "zipformer2", model_type
decode_chunk_len = int(encoder_meta["decode_chunk_len"])
T = int(encoder_meta["T"])
num_encoder_layers = encoder_meta["num_encoder_layers"]
encoder_dims = encoder_meta["encoder_dims"]
cnn_module_kernels = encoder_meta["cnn_module_kernels"]
left_context_len = encoder_meta["left_context_len"]
query_head_dims = encoder_meta["query_head_dims"]
value_head_dims = encoder_meta["value_head_dims"]
num_heads = encoder_meta["num_heads"]
def to_int_list(s):
return list(map(int, s.split(",")))
num_encoder_layers = to_int_list(num_encoder_layers)
encoder_dims = to_int_list(encoder_dims)
cnn_module_kernels = to_int_list(cnn_module_kernels)
left_context_len = to_int_list(left_context_len)
query_head_dims = to_int_list(query_head_dims)
value_head_dims = to_int_list(value_head_dims)
num_heads = to_int_list(num_heads)
logging.info(f"decode_chunk_len: {decode_chunk_len}")
logging.info(f"T: {T}")
logging.info(f"num_encoder_layers: {num_encoder_layers}")
logging.info(f"encoder_dims: {encoder_dims}")
logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
logging.info(f"left_context_len: {left_context_len}")
logging.info(f"query_head_dims: {query_head_dims}")
logging.info(f"value_head_dims: {value_head_dims}")
logging.info(f"num_heads: {num_heads}")
num_encoders = len(num_encoder_layers)
self.states = []
for i in range(num_encoders):
num_layers = num_encoder_layers[i]
key_dim = query_head_dims[i] * num_heads[i]
embed_dim = encoder_dims[i]
nonlin_attn_head_dim = 3 * embed_dim // 4
value_dim = value_head_dims[i] * num_heads[i]
conv_left_pad = cnn_module_kernels[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(
left_context_len[i], batch_size, key_dim
).numpy()
cached_nonlin_attn = torch.zeros(
1, batch_size, left_context_len[i], nonlin_attn_head_dim
).numpy()
cached_val1 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_val2 = torch.zeros(
left_context_len[i], batch_size, value_dim
).numpy()
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
self.states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
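        # (batch_size, num_channels, left_pad, num_freqs); this must match the
        # embed_states cache created by OnnxEncoder.get_init_states() in
        # ./export-onnx-streaming.py.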
embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
self.states.append(processed_lens)
self.num_encoders = num_encoders
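        # Each encoder call consumes `segment` = T feature frames, but the
        # window only advances by `offset` = decode_chunk_len frames, so
        # consecutive chunks share T - decode_chunk_len look-ahead frames.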
self.segment = T
self.offset = decode_chunk_len
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
self.context_size = int(decoder_meta["context_size"])
self.vocab_size = int(decoder_meta["vocab_size"])
logging.info(f"context_size: {self.context_size}")
logging.info(f"vocab_size: {self.vocab_size}")
def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
self.joiner_dim = int(joiner_meta["joiner_dim"])
logging.info(f"joiner_dim: {self.joiner_dim}")
def _build_encoder_input_output(
self,
x: torch.Tensor,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
encoder_input = {"x": x.numpy()}
encoder_output = ["encoder_out"]
def build_inputs_outputs(tensors, i):
assert len(tensors) == 6, len(tensors)
# (downsample_left, batch_size, key_dim)
name = f"cached_key_{i}"
encoder_input[name] = tensors[0]
encoder_output.append(f"new_{name}")
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
name = f"cached_nonlin_attn_{i}"
encoder_input[name] = tensors[1]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val1_{i}"
encoder_input[name] = tensors[2]
encoder_output.append(f"new_{name}")
# (downsample_left, batch_size, value_dim)
name = f"cached_val2_{i}"
encoder_input[name] = tensors[3]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv1_{i}"
encoder_input[name] = tensors[4]
encoder_output.append(f"new_{name}")
# (batch_size, embed_dim, conv_left_pad)
name = f"cached_conv2_{i}"
encoder_input[name] = tensors[5]
encoder_output.append(f"new_{name}")
for i in range(len(self.states[:-2]) // 6):
build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
# (batch_size, channels, left_pad, freq)
name = "embed_states"
embed_states = self.states[-2]
encoder_input[name] = embed_states
encoder_output.append(f"new_{name}")
# (batch_size,)
name = "processed_lens"
processed_lens = self.states[-1]
encoder_input[name] = processed_lens
encoder_output.append(f"new_{name}")
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
self.states = states
def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C)
Returns:
Return a 3-D tensor of shape (N, T', joiner_dim) where
T' is usually equal to ((T-7)//2+1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
out = self.encoder.run(encoder_output_names, encoder_input)
self._update_states(out[1:])
return torch.from_numpy(out[0])
def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
"""
Args:
decoder_input:
A 2-D tensor of shape (N, context_size)
Returns:
Return a 2-D tensor of shape (N, joiner_dim)
"""
out = self.decoder.run(
[self.decoder.get_outputs()[0].name],
{self.decoder.get_inputs()[0].name: decoder_input.numpy()},
)[0]
return torch.from_numpy(out)
def run_joiner(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
A 2-D tensor of shape (N, joiner_dim)
decoder_out:
A 2-D tensor of shape (N, joiner_dim)
Returns:
Return a 2-D tensor of shape (N, vocab_size)
"""
out = self.joiner.run(
[self.joiner.get_outputs()[0].name],
{
self.joiner.get_inputs()[0].name: encoder_out.numpy(),
self.joiner.get_inputs()[1].name: decoder_out.numpy(),
},
)[0]
return torch.from_numpy(out)
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
def create_streaming_feature_extractor() -> OnlineFeature:
"""Create a CPU streaming feature extractor.
At present, we assume it returns a fbank feature extractor with
fixed options. In the future, we will support passing in the options
from outside.
Returns:
Return a CPU streaming feature extractor.
"""
opts = FbankOptions()
opts.device = "cpu"
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
return OnlineFbank(opts)
def greedy_search(
model: OnnxModel,
encoder_out: torch.Tensor,
context_size: int,
decoder_out: Optional[torch.Tensor] = None,
hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
A 3-D tensor of shape (1, T, joiner_dim)
context_size:
The context size of the decoder model.
decoder_out:
Optional. Decoder output of the previous chunk.
hyp:
Decoding results for previous chunks.
Returns:
Return the decoded results so far.
"""
blank_id = 0
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor([hyp], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
else:
assert hyp is not None, hyp
encoder_out = encoder_out.squeeze(0)
T = encoder_out.size(0)
for t in range(T):
cur_encoder_out = encoder_out[t : t + 1]
joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
y = joiner_out.argmax(dim=0).item()
if y != blank_id:
hyp.append(y)
decoder_input = hyp[-context_size:]
decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
decoder_out = model.run_decoder(decoder_input)
return hyp, decoder_out
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
logging.info(vars(args))
model = OnnxModel(
encoder_model_filename=args.encoder_model_filename,
decoder_model_filename=args.decoder_model_filename,
joiner_model_filename=args.joiner_model_filename,
)
sample_rate = 16000
logging.info("Constructing Fbank computer")
online_fbank = create_streaming_feature_extractor()
logging.info(f"Reading sound files: {args.sound_file}")
waves = read_sound_files(
filenames=[args.sound_file],
expected_sample_rate=sample_rate,
)[0]
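    # The trailing 0.3 s of silence helps flush the last few speech frames that
    # would otherwise remain in an incomplete chunk.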
tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
wave_samples = torch.cat([waves, tail_padding])
num_processed_frames = 0
segment = model.segment
offset = model.offset
context_size = model.context_size
hyp = None
decoder_out = None
chunk = int(1 * sample_rate) # 1 second
start = 0
while start < wave_samples.numel():
end = min(start + chunk, wave_samples.numel())
samples = wave_samples[start:end]
start += chunk
online_fbank.accept_waveform(
sampling_rate=sample_rate,
waveform=samples,
)
while online_fbank.num_frames_ready - num_processed_frames >= segment:
frames = []
for i in range(segment):
frames.append(online_fbank.get_frame(num_processed_frames + i))
num_processed_frames += offset
frames = torch.cat(frames, dim=0)
frames = frames.unsqueeze(0)
encoder_out = model.run_encoder(frames)
hyp, decoder_out = greedy_search(
model,
encoder_out,
context_size,
decoder_out,
hyp,
)
symbol_table = k2.SymbolTable.from_file(args.tokens)
text = ""
for i in hyp[context_size:]:
text += symbol_table[i]
    text = text.replace("▁", " ").strip()
logging.info(args.sound_file)
logging.info(text)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()


@ -0,0 +1 @@
../pruned_transducer_stateless7/onnx_pretrained.py

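The file above is a symlink to the existing non-streaming ONNX decoding script,
so its body does not appear in this diff. As a rough sanity check of the
non-streaming export, a sketch along the following lines is enough (assuming
onnxruntime and the three files produced by ./zipformer/export-onnx.py); it is
an illustration, not the script the symlink points to.

import onnxruntime as ort
import torch

encoder = ort.InferenceSession("encoder-epoch-99-avg-1.onnx")
decoder = ort.InferenceSession("decoder-epoch-99-avg-1.onnx")
joiner = ort.InferenceSession("joiner-epoch-99-avg-1.onnx")

context_size = int(decoder.get_modelmeta().custom_metadata_map["context_size"])

# Random 80-dim fbank features standing in for a real utterance.
x = torch.rand(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)

encoder_out, encoder_out_lens = encoder.run(
    ["encoder_out", "encoder_out_lens"],
    {"x": x.numpy(), "x_lens": x_lens.numpy()},
)

# Blank decoder context, then join with the first encoder frame.
y = torch.zeros(1, context_size, dtype=torch.int64)
decoder_out = decoder.run(["decoder_out"], {"y": y.numpy()})[0]
logit = joiner.run(
    ["logit"],
    {"encoder_out": encoder_out[:, 0], "decoder_out": decoder_out},
)[0]
print(logit.shape)  # (1, vocab_size)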

@ -26,6 +26,18 @@ import torch.nn as nn
from torch import Tensor
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing():
return torch.logaddexp(x, y)
else:
return (x.exp() + y.exp()).log()
class PiecewiseLinear(object):
"""
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
def __float__(self):
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return float(self.default)
else:
ans = self.schedule(self.batch_count)
@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
def softmax(x: Tensor, dim: int):
if not x.requires_grad or torch.jit.is_scripting():
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
return x.softmax(dim=dim)
return SoftmaxFunction.apply(x, dim)
@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
self.alpha = alpha
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return x
return scale_grad(x, self.alpha)
@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
def _no_op(x: Tensor) -> Tensor:
if (torch.jit.is_scripting()):
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x
else:
# a no-op function that will have a node in the autograd graph,
@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-L activation.
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not x.requires_grad:
return k2.swoosh_l_forward(x)
else:
@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-R activation.
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
if not x.requires_grad:
return k2.swoosh_r_forward(x)
else:


@ -27,6 +27,7 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:


@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():


@ -133,6 +133,7 @@ class Zipformer2(EncoderInterface):
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
num_encoder_layers = _to_tuple(num_encoder_layers)
self.num_encoder_layers = num_encoder_layers
self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
pos_head_dim = _to_tuple(pos_head_dim)
@ -258,7 +259,7 @@ class Zipformer2(EncoderInterface):
if not self.causal:
return -1, -1
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.chunk_size) == 1, self.chunk_size
chunk_size = self.chunk_size[0]
else:
@ -267,7 +268,7 @@ class Zipformer2(EncoderInterface):
if chunk_size == -1:
left_context_chunks = -1
else:
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.left_context_frames) == 1, self.left_context_frames
left_context_frames = self.left_context_frames[0]
else:
@ -301,14 +302,14 @@ class Zipformer2(EncoderInterface):
of frames in `embeddings` before padding.
"""
outputs = []
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
feature_masks = [1.0] * len(self.encoder_dim)
else:
feature_masks = self.get_feature_masks(x)
chunk_size, left_context_chunks = self.get_chunk_info()
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
# Not support exporting a model for simulating streaming decoding
attn_mask = None
else:
@ -334,7 +335,7 @@ class Zipformer2(EncoderInterface):
x = self.downsample_output(x)
# class Downsample has this rounding behavior.
assert self.output_downsampling_factor == 2
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
lengths = (x_lens + 1) // 2
else:
with warnings.catch_warnings():
@ -372,7 +373,7 @@ class Zipformer2(EncoderInterface):
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
# c is chunk index for each frame, shape (seq_len,)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
c = t // chunk_size
else:
with warnings.catch_warnings():
@ -650,7 +651,7 @@ class Zipformer2EncoderLayer(nn.Module):
)
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return None
batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@ -695,7 +696,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_orig = src
# dropout rate for non-feedforward submodules
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
attention_skip_rate = 0.0
else:
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@ -713,7 +714,7 @@ class Zipformer2EncoderLayer(nn.Module):
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif not self.training and random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to
@ -732,7 +733,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -740,7 +741,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
conv_skip_rate)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
ff2_skip_rate = 0.0
else:
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@ -754,7 +755,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -762,7 +763,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
conv_skip_rate)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
ff3_skip_rate = 0.0
else:
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
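
All of the *_skip_rate branches above follow the same reasoning: data-independent Python randomness cannot survive tracing, because whatever random.random() returned at trace time is frozen into the graph. A toy illustration (not the icefall code itself):

    import random
    import torch

    def noisy_scale(x: torch.Tensor) -> torch.Tensor:
        # python-level randomness: invisible to the tracer
        skip = 0.0 if random.random() < 0.5 else 1.0
        return x * skip

    traced = torch.jit.trace(noisy_scale, torch.ones(2))
    # `traced` always multiplies by the single constant picked during tracing
    # (all zeros or all ones), which is why the skip rates are forced to 0.0
    # whenever is_tracing() is True.
    print(traced(torch.ones(2)))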
@ -968,7 +969,7 @@ class Zipformer2Encoder(nn.Module):
pos_emb = self.encoder_pos(src)
output = src
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask
for i, mod in enumerate(self.layers):
@ -980,7 +981,7 @@ class Zipformer2Encoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
)
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask
return output
@ -1073,7 +1074,7 @@ class BypassModule(nn.Module):
# or (batch_size, num_channels,). This is actually the
# scale on the non-residual term, so 0 corresponds to bypassing
# this module.
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.bypass_scale
else:
ans = limit_param_value(self.bypass_scale,
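
To make the comment about bypass_scale concrete, a small sketch of the interpolation it controls (the exact forward code is not shown in this hunk; the formula below is the "scale on the non-residual term" reading, so scale 0 returns the input unchanged and scale 1 returns the layer output):

    import torch

    src_orig = torch.tensor([1.0, 2.0])  # input to the bypassed sub-stack
    src = torch.tensor([5.0, 8.0])       # output of the sub-stack
    for scale in (0.0, 0.5, 1.0):
        out = src_orig + scale * (src - src_orig)
        print(scale, out)
    # 0.0 -> tensor([1., 2.])   (fully bypassed)
    # 0.5 -> tensor([3., 5.])
    # 1.0 -> tensor([5., 8.])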
@ -1229,7 +1230,6 @@ class SimpleDownsample(torch.nn.Module):
d_seq_len = (seq_len + ds - 1) // ds
# Pad to an exact multiple of self.downsample
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
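
The padding arithmetic in this hunk is easy to verify by hand; a minimal numeric sketch with hypothetical sizes (the final torch.cat is an assumed follow-up step, not shown here):

    import torch

    seq_len, batch, dim, ds = 7, 1, 3, 2
    src = torch.arange(seq_len * batch * dim, dtype=torch.float32).reshape(seq_len, batch, dim)

    d_seq_len = (seq_len + ds - 1) // ds      # 4 output frames
    pad = d_seq_len * ds - seq_len            # 1 frame of right padding
    src_extra = src[src.shape[0] - 1:].expand(pad, src.shape[1], src.shape[2])
    src = torch.cat((src, src_extra), dim=0)  # assumed follow-up step
    print(src.shape)                          # torch.Size([8, 1, 3])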
@ -1322,10 +1322,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
@ -1524,7 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k)
use_pos_scores = False
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
# We can't put random.random() in the same line
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@ -1542,6 +1538,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0),
pos_scores.stride(1),
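
To see what the traced branch above computes, here is a self-contained toy run (toy sizes only; that it matches the strided-view variant is taken from the surrounding comment rather than re-derived here). Each query row t selects offsets starting at time1 - 1 - t, which converts the last axis from relative offset to absolute key position:

    import torch

    num_heads, batch_size, time1 = 2, 1, 4
    seq_len = time1
    n = 2 * time1 - 1                                         # number of relative offsets
    pos_scores = torch.randn(num_heads, batch_size, time1, n)

    rows = torch.arange(start=time1 - 1, end=-1, step=-1)     # (time1,)
    cols = torch.arange(seq_len)                              # (seq_len,)
    rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)  # (H*B*time1, 1)
    indexes = rows + cols                                     # (H*B*time1, seq_len)

    flat = pos_scores.reshape(-1, n)                          # one row per (head, batch, t)
    out = torch.gather(flat, dim=1, index=indexes)
    out = out.reshape(num_heads, batch_size, time1, seq_len)
    print(out.shape)                                          # torch.Size([2, 1, 4, 4])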
@ -1551,7 +1557,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif self.training and random.random() < 0.1:
# This is a harder way of limiting the attention scores to not be
@ -1594,7 +1600,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# half-precision output for backprop purposes.
attn_weights = softmax(attn_scores, dim=-1)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif random.random() < 0.001 and not self.training:
self._print_attn_entropy(attn_weights)
@ -1672,9 +1678,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
if torch.jit.is_tracing():
(num_heads, batch_size, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(k_len)
rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
indexes = rows + cols
pos_scores = pos_scores.reshape(-1, n)
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
else:
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
(pos_scores.stride(0),
pos_scores.stride(1),
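
This second traced branch is the streaming analogue, where the key length (current chunk plus cached left context) differs from the query length; the same gather trick applies, only with k_len columns. A short check of the index bounds (the actual relative-offset count n comes from the cached positional embedding and is assumed here to be time1 + k_len - 1):

    time1, k_len = 4, 12                   # 4 queries, 12 keys incl. left context
    n = time1 + k_len - 1                  # assumed relative-offset count
    max_index = (time1 - 1) + (k_len - 1)  # largest entry in `rows + cols`
    assert max_index < n                   # gather indices stay in range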
@ -2136,7 +2153,7 @@ class ConvolutionModule(nn.Module):
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if not torch.jit.is_scripting() and chunk_size >= 0:
if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
# Exporting a model for simulated streaming decoding is not supported
assert self.causal, "Must initialize model with causal=True if you use chunk_size"
x = self.depthwise_conv(x, chunk_size=chunk_size)