mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

add updated zipformer onnx export (#1108)
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

commit 0cb71ad3bc (parent b4c38d7547)

775  egs/librispeech/ASR/zipformer/export-onnx-streaming.py  (new executable file)
@@ -0,0 +1,775 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports a transducer model from PyTorch to ONNX.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx-streaming.py \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal True \
  --chunk-size 16 \
  --left-context-frames 64

The --chunk-size used in training is "16,32,64,-1", so we select one of those
values (excluding -1) for the streaming export. The same applies to
`--left-context-frames`, whose training value is "64,128,256,-1".

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, List, Tuple

import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask


def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = value

    onnx.save(model, filename)


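# A minimal usage sketch for add_meta_data() (illustrative only; the real calls
# appear in the export_*_model_onnx functions below, and the file name here is
# hypothetical). Metadata written this way can be read back with onnx, or at
# runtime via onnxruntime's get_modelmeta().custom_metadata_map.
def _example_add_and_read_meta_data(filename: str = "foo.onnx") -> Dict[str, str]:
    add_meta_data(filename=filename, meta_data={"model_type": "zipformer2"})
    model = onnx.load(filename)
    return {p.key: p.value for p in model.metadata_props}

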
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Zipformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_embed = encoder_embed
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
self.chunk_size = encoder.chunk_size[0]
|
||||||
|
self.left_context_len = encoder.left_context_frames[0]
|
||||||
|
self.pad_length = 7 + 2 * 3
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
states: List[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||||
|
N = x.size(0)
|
||||||
|
T = self.chunk_size * 2 + self.pad_length
|
||||||
|
x_lens = torch.tensor([T] * N, device=x.device)
|
||||||
|
left_context_len = self.left_context_len
|
||||||
|
|
||||||
|
cached_embed_left_pad = states[-2]
|
||||||
|
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
|
||||||
|
x=x,
|
||||||
|
x_lens=x_lens,
|
||||||
|
cached_left_pad=cached_embed_left_pad,
|
||||||
|
)
|
||||||
|
assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size)
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
|
||||||
|
# processed_mask is used to mask out initial states
|
||||||
|
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||||
|
x.size(0), left_context_len
|
||||||
|
)
|
||||||
|
processed_lens = states[-1] # (batch,)
|
||||||
|
# (batch, left_context_size)
|
||||||
|
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
||||||
|
# Update processed lengths
|
||||||
|
new_processed_lens = processed_lens + x_lens
|
||||||
|
# (batch, left_context_size + chunk_size)
|
||||||
|
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||||
|
|
||||||
|
x = x.permute(1, 0, 2)
|
||||||
|
encoder_states = states[:-2]
|
||||||
|
logging.info(f"len_encoder_states={len(encoder_states)}")
|
||||||
|
(
|
||||||
|
encoder_out,
|
||||||
|
encoder_out_lens,
|
||||||
|
new_encoder_states,
|
||||||
|
) = self.encoder.streaming_forward(
|
||||||
|
x=x,
|
||||||
|
x_lens=x_lens,
|
||||||
|
states=encoder_states,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
)
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2)
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
new_states = new_encoder_states + [
|
||||||
|
new_cached_embed_left_pad,
|
||||||
|
new_processed_lens,
|
||||||
|
]
|
||||||
|
|
||||||
|
return encoder_out, new_states
|
||||||
|
|
||||||
|
def get_init_states(
|
||||||
|
self,
|
||||||
|
batch_size: int = 1,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
|
||||||
|
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
|
||||||
|
states[-2] is the cached left padding for ConvNeXt module,
|
||||||
|
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||||
|
states[-1] is processed_lens of shape (batch,), which records the number
|
||||||
|
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||||
|
"""
|
||||||
|
states = self.encoder.get_init_states(batch_size, device)
|
||||||
|
|
||||||
|
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||||
|
states.append(embed_states)
|
||||||
|
|
||||||
|
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||||
|
states.append(processed_lens)
|
||||||
|
|
||||||
|
return states
|
||||||
|
|
||||||
|
|
||||||
|
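# Illustrative helper (not used by the export code) that mirrors the state
# layout documented in get_init_states(): 6 cached tensors per Zipformer2
# layer, plus embed_states and processed_lens. With the example configuration
# --num-encoder-layers "2,2,3,4,3,2" there are 16 layers in total, so
# get_init_states() returns 6 * 16 + 2 = 98 tensors.
def _expected_num_states(num_encoder_layers=(2, 2, 3, 4, 3, 2)) -> int:
    return 6 * sum(num_encoder_layers) + 2

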
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
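# Sketch of one greedy-search step with the exported decoder and joiner.
# Assumptions: onnxruntime is installed and the example file names from the
# module docstring are used; see ./onnx_pretrained-streaming.py for the real
# decoding loop.
def _example_joiner_step() -> int:
    import numpy as np
    import onnxruntime as ort

    decoder_sess = ort.InferenceSession("decoder-epoch-99-avg-1.onnx")
    joiner_sess = ort.InferenceSession("joiner-epoch-99-avg-1.onnx")

    context_size = int(
        decoder_sess.get_modelmeta().custom_metadata_map["context_size"]
    )
    y = np.zeros((1, context_size), dtype=np.int64)  # start from blank tokens
    decoder_out = decoder_sess.run(["decoder_out"], {"y": y})[0]

    # In a real decoding loop this comes from one output frame of the encoder;
    # zeros are used here just to exercise the joiner.
    encoder_out = np.zeros_like(decoder_out)
    logit = joiner_sess.run(
        ["logit"], {"encoder_out": encoder_out, "decoder_out": decoder_out}
    )[0]
    return int(logit.argmax(axis=-1)[0])

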
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
encoder_model.encoder.__class__.forward = (
|
||||||
|
encoder_model.encoder.__class__.streaming_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
decode_chunk_len = encoder_model.chunk_size * 2
|
||||||
|
# The encoder_embed subsamples the features: (T - 7) // 2 frames remain
|
||||||
|
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||||
|
T = decode_chunk_len + encoder_model.pad_length
|
||||||
|
|
||||||
|
x = torch.rand(1, T, 80, dtype=torch.float32)
|
||||||
|
init_state = encoder_model.get_init_states()
|
||||||
|
num_encoders = len(encoder_model.encoder.encoder_dim)
|
||||||
|
logging.info(f"num_encoders: {num_encoders}")
|
||||||
|
logging.info(f"len(init_state): {len(init_state)}")
|
||||||
|
|
||||||
|
inputs = {}
|
||||||
|
input_names = ["x"]
|
||||||
|
|
||||||
|
outputs = {}
|
||||||
|
output_names = ["encoder_out"]
|
||||||
|
|
||||||
|
def build_inputs_outputs(tensors, i):
|
||||||
|
assert len(tensors) == 6, len(tensors)
|
||||||
|
|
||||||
|
# (downsample_left, batch_size, key_dim)
|
||||||
|
name = f"cached_key_{i}"
|
||||||
|
logging.info(f"{name}.shape: {tensors[0].shape}")
|
||||||
|
inputs[name] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (1, batch_size, downsample_left, nonlin_attn_head_dim)
|
||||||
|
name = f"cached_nonlin_attn_{i}"
|
||||||
|
logging.info(f"{name}.shape: {tensors[1].shape}")
|
||||||
|
inputs[name] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (downsample_left, batch_size, value_dim)
|
||||||
|
name = f"cached_val1_{i}"
|
||||||
|
logging.info(f"{name}.shape: {tensors[2].shape}")
|
||||||
|
inputs[name] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (downsample_left, batch_size, value_dim)
|
||||||
|
name = f"cached_val2_{i}"
|
||||||
|
logging.info(f"{name}.shape: {tensors[3].shape}")
|
||||||
|
inputs[name] = {1: "N"}
|
||||||
|
outputs[f"new_{name}"] = {1: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (batch_size, embed_dim, conv_left_pad)
|
||||||
|
name = f"cached_conv1_{i}"
|
||||||
|
logging.info(f"{name}.shape: {tensors[4].shape}")
|
||||||
|
inputs[name] = {0: "N"}
|
||||||
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (batch_size, embed_dim, conv_left_pad)
|
||||||
|
name = f"cached_conv2_{i}"
|
||||||
|
logging.info(f"{name}.shape: {tensors[5].shape}")
|
||||||
|
inputs[name] = {0: "N"}
|
||||||
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers))
|
||||||
|
encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim))
|
||||||
|
cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel))
|
||||||
|
ds = encoder_model.encoder.downsampling_factor
|
||||||
|
left_context_len = encoder_model.left_context_len
|
||||||
|
left_context_len = [left_context_len // k for k in ds]
|
||||||
|
left_context_len = ",".join(map(str, left_context_len))
|
||||||
|
query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim))
|
||||||
|
value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim))
|
||||||
|
num_heads = ",".join(map(str, encoder_model.encoder.num_heads))
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "zipformer2",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "streaming zipformer2",
|
||||||
|
"decode_chunk_len": str(decode_chunk_len), # 32
|
||||||
|
"T": str(T), # 32+7+2*3=45
|
||||||
|
"num_encoder_layers": num_encoder_layers,
|
||||||
|
"encoder_dims": encoder_dims,
|
||||||
|
"cnn_module_kernels": cnn_module_kernels,
|
||||||
|
"left_context_len": left_context_len,
|
||||||
|
"query_head_dims": query_head_dims,
|
||||||
|
"value_head_dims": value_head_dims,
|
||||||
|
"num_heads": num_heads,
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
for i in range(len(init_state[:-2]) // 6):
|
||||||
|
build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i)
|
||||||
|
|
||||||
|
# (batch_size, channels, left_pad, freq)
|
||||||
|
embed_states = init_state[-2]
|
||||||
|
name = "embed_states"
|
||||||
|
logging.info(f"{name}.shape: {embed_states.shape}")
|
||||||
|
inputs[name] = {0: "N"}
|
||||||
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
# (batch_size,)
|
||||||
|
processed_lens = init_state[-1]
|
||||||
|
name = "processed_lens"
|
||||||
|
logging.info(f"{name}.shape: {processed_lens.shape}")
|
||||||
|
inputs[name] = {0: "N"}
|
||||||
|
outputs[f"new_{name}"] = {0: "N"}
|
||||||
|
input_names.append(name)
|
||||||
|
output_names.append(f"new_{name}")
|
||||||
|
|
||||||
|
logging.info(inputs)
|
||||||
|
logging.info(outputs)
|
||||||
|
logging.info(input_names)
|
||||||
|
logging.info(output_names)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, init_state),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=input_names,
|
||||||
|
output_names=output_names,
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N"},
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
**inputs,
|
||||||
|
**outputs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
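# A short sketch (assuming onnxruntime is installed and using the example file
# name from the docstring) showing how to inspect the exported streaming
# encoder: the metadata added above and the cached-state input names are all
# visible from the ONNX file itself.
def _inspect_streaming_encoder(filename: str = "encoder-epoch-99-avg-1.onnx") -> None:
    import onnxruntime as ort

    sess = ort.InferenceSession(filename)
    meta = sess.get_modelmeta().custom_metadata_map
    # e.g. decode_chunk_len=32 and T=45 when exported with --chunk-size 16
    print(meta["decode_chunk_len"], meta["T"])
    for node in sess.get_inputs():
        # x, cached_key_0, ..., cached_conv2_15, embed_states, processed_lens
        print(node.name, node.shape)

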
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_embed=model.encoder_embed,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
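# Optional sanity check (a sketch, not part of this script): compare the fp32
# and int8 decoder outputs on the same token context. The file names follow
# the example in the module docstring; context_size is 2 in the example setup.
def _check_int8_decoder(
    fp32_file: str = "decoder-epoch-99-avg-1.onnx",
    int8_file: str = "decoder-epoch-99-avg-1.int8.onnx",
) -> float:
    import numpy as np
    import onnxruntime as ort

    y = np.zeros((1, 2), dtype=np.int64)
    a = ort.InferenceSession(fp32_file).run(["decoder_out"], {"y": y})[0]
    b = ort.InferenceSession(int8_file).run(["decoder_out"], {"y": y})[0]
    return float(np.max(np.abs(a - b)))  # expected to be small

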
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
624  egs/librispeech/ASR/zipformer/export-onnx.py  (new executable file)

@@ -0,0 +1,624 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script exports a transducer model from PyTorch to ONNX.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx.py \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal False \
  --chunk-size "16,32,64,-1" \
  --left-context-frames "64,128,256,-1"

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

See ./onnx_pretrained.py and ./onnx_check.py for how to
use the exported ONNX models.
"""

import argparse
import logging
from pathlib import Path
from typing import Dict, Tuple

import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import Zipformer2

from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask


def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=28,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--bpe-model",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500/bpe.model",
|
||||||
|
help="Path to the BPE model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context-size",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
"""Add meta data to an ONNX model. It is changed in-place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename:
|
||||||
|
Filename of the ONNX model to be changed.
|
||||||
|
meta_data:
|
||||||
|
Key-value pairs.
|
||||||
|
"""
|
||||||
|
model = onnx.load(filename)
|
||||||
|
for key, value in meta_data.items():
|
||||||
|
meta = model.metadata_props.add()
|
||||||
|
meta.key = key
|
||||||
|
meta.value = value
|
||||||
|
|
||||||
|
onnx.save(model, filename)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxEncoder(nn.Module):
|
||||||
|
"""A wrapper for Zipformer and the encoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
A Zipformer encoder.
|
||||||
|
encoder_proj:
|
||||||
|
The projection layer for encoder from the joiner.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_embed = encoder_embed
|
||||||
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Please see the help information of Zipformer.forward
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C)
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). Its dtype is torch.int64
|
||||||
|
Returns:
|
||||||
|
Return a tuple containing:
|
||||||
|
- encoder_out, A 3-D tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, A 1-D tensor of shape (N,)
|
||||||
|
"""
|
||||||
|
x, x_lens = self.encoder_embed(x, x_lens)
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
x = x.permute(1, 0, 2)
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2)
|
||||||
|
encoder_out = self.encoder_proj(encoder_out)
|
||||||
|
# Now encoder_out is of shape (N, T, joiner_dim)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
|
||||||
|
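# Small illustration of the masking used in OnnxEncoder.forward above
# (assuming icefall's make_pad_mask semantics): positions beyond each
# utterance length are True and are therefore ignored by the encoder.
def _example_pad_mask() -> torch.Tensor:
    lens = torch.tensor([2, 3])
    # expected:
    #   tensor([[False, False,  True],
    #           [False, False, False]])
    return make_pad_mask(lens)

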
class OnnxDecoder(nn.Module):
|
||||||
|
"""A wrapper for Decoder and the decoder_proj from the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, decoder: Decoder, decoder_proj: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder = decoder
|
||||||
|
self.decoder_proj = decoder_proj
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y:
|
||||||
|
A 2-D tensor of shape (N, context_size).
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, joiner_dim)
|
||||||
|
"""
|
||||||
|
need_pad = False
|
||||||
|
decoder_output = self.decoder(y, need_pad=need_pad)
|
||||||
|
decoder_output = decoder_output.squeeze(1)
|
||||||
|
output = self.decoder_proj(decoder_output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxJoiner(nn.Module):
|
||||||
|
"""A wrapper for the joiner"""
|
||||||
|
|
||||||
|
def __init__(self, output_linear: nn.Linear):
|
||||||
|
super().__init__()
|
||||||
|
self.output_linear = output_linear
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
decoder_out:
|
||||||
|
A 2-D tensor of shape (N, joiner_dim)
|
||||||
|
Returns:
|
||||||
|
Return a 2-D tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
logit = encoder_out + decoder_out
|
||||||
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
return logit
|
||||||
|
|
||||||
|
|
||||||
|
def export_encoder_model_onnx(
|
||||||
|
encoder_model: OnnxEncoder,
|
||||||
|
encoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the given encoder model to ONNX format.
|
||||||
|
The exported model has two inputs:
|
||||||
|
|
||||||
|
- x, a tensor of shape (N, T, C); dtype is torch.float32
|
||||||
|
- x_lens, a tensor of shape (N,); dtype is torch.int64
|
||||||
|
|
||||||
|
and it has two outputs:
|
||||||
|
|
||||||
|
- encoder_out, a tensor of shape (N, T', joiner_dim)
|
||||||
|
- encoder_out_lens, a tensor of shape (N,)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_model:
|
||||||
|
The input encoder model
|
||||||
|
encoder_filename:
|
||||||
|
The filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
x = torch.zeros(1, 100, 80, dtype=torch.float32)
|
||||||
|
x_lens = torch.tensor([100], dtype=torch.int64)
|
||||||
|
|
||||||
|
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
encoder_model,
|
||||||
|
(x, x_lens),
|
||||||
|
encoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["x", "x_lens"],
|
||||||
|
output_names=["encoder_out", "encoder_out_lens"],
|
||||||
|
dynamic_axes={
|
||||||
|
"x": {0: "N", 1: "T"},
|
||||||
|
"x_lens": {0: "N"},
|
||||||
|
"encoder_out": {0: "N", 1: "T"},
|
||||||
|
"encoder_out_lens": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"model_type": "zipformer2",
|
||||||
|
"version": "1",
|
||||||
|
"model_author": "k2-fsa",
|
||||||
|
"comment": "non-streaming zipformer2",
|
||||||
|
}
|
||||||
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
|
add_meta_data(filename=encoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
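# A minimal sketch (assumptions: onnxruntime installed, example file name from
# the module docstring) of running the exported non-streaming encoder once.
def _example_run_encoder(filename: str = "encoder-epoch-99-avg-1.onnx"):
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(filename)
    x = np.random.randn(1, 200, 80).astype(np.float32)  # (N, T, C) fbank frames
    x_lens = np.array([200], dtype=np.int64)
    encoder_out, encoder_out_lens = sess.run(
        ["encoder_out", "encoder_out_lens"], {"x": x, "x_lens": x_lens}
    )
    # encoder_out has shape (N, T', joiner_dim)
    return encoder_out, encoder_out_lens

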
def export_decoder_model_onnx(
|
||||||
|
decoder_model: OnnxDecoder,
|
||||||
|
decoder_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the decoder model to ONNX format.
|
||||||
|
|
||||||
|
The exported model has one input:
|
||||||
|
|
||||||
|
- y: a torch.int64 tensor of shape (N, decoder_model.context_size)
|
||||||
|
|
||||||
|
and has one output:
|
||||||
|
|
||||||
|
- decoder_out: a torch.float32 tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_model:
|
||||||
|
The decoder model to be exported.
|
||||||
|
decoder_filename:
|
||||||
|
Filename to save the exported ONNX model.
|
||||||
|
opset_version:
|
||||||
|
The opset version to use.
|
||||||
|
"""
|
||||||
|
context_size = decoder_model.decoder.context_size
|
||||||
|
vocab_size = decoder_model.decoder.vocab_size
|
||||||
|
|
||||||
|
y = torch.zeros(10, context_size, dtype=torch.int64)
|
||||||
|
torch.onnx.export(
|
||||||
|
decoder_model,
|
||||||
|
y,
|
||||||
|
decoder_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=["y"],
|
||||||
|
output_names=["decoder_out"],
|
||||||
|
dynamic_axes={
|
||||||
|
"y": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
meta_data = {
|
||||||
|
"context_size": str(context_size),
|
||||||
|
"vocab_size": str(vocab_size),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=decoder_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
def export_joiner_model_onnx(
|
||||||
|
joiner_model: nn.Module,
|
||||||
|
joiner_filename: str,
|
||||||
|
opset_version: int = 11,
|
||||||
|
) -> None:
|
||||||
|
"""Export the joiner model to ONNX format.
|
||||||
|
The exported joiner model has two inputs:
|
||||||
|
|
||||||
|
- encoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
- decoder_out: a tensor of shape (N, joiner_dim)
|
||||||
|
|
||||||
|
and produces one output:
|
||||||
|
|
||||||
|
- logit: a tensor of shape (N, vocab_size)
|
||||||
|
"""
|
||||||
|
joiner_dim = joiner_model.output_linear.weight.shape[1]
|
||||||
|
logging.info(f"joiner dim: {joiner_dim}")
|
||||||
|
|
||||||
|
projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
joiner_model,
|
||||||
|
(projected_encoder_out, projected_decoder_out),
|
||||||
|
joiner_filename,
|
||||||
|
verbose=False,
|
||||||
|
opset_version=opset_version,
|
||||||
|
input_names=[
|
||||||
|
"encoder_out",
|
||||||
|
"decoder_out",
|
||||||
|
],
|
||||||
|
output_names=["logit"],
|
||||||
|
dynamic_axes={
|
||||||
|
"encoder_out": {0: "N"},
|
||||||
|
"decoder_out": {0: "N"},
|
||||||
|
"logit": {0: "N"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
meta_data = {
|
||||||
|
"joiner_dim": str(joiner_dim),
|
||||||
|
}
|
||||||
|
add_meta_data(filename=joiner_filename, meta_data=meta_data)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
sp = spm.SentencePieceProcessor()
|
||||||
|
sp.load(params.bpe_model)
|
||||||
|
|
||||||
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
|
||||||
|
|
||||||
|
encoder = OnnxEncoder(
|
||||||
|
encoder=model.encoder,
|
||||||
|
encoder_embed=model.encoder_embed,
|
||||||
|
encoder_proj=model.joiner.encoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder = OnnxDecoder(
|
||||||
|
decoder=model.decoder,
|
||||||
|
decoder_proj=model.joiner.decoder_proj,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner = OnnxJoiner(output_linear=model.joiner.output_linear)
|
||||||
|
|
||||||
|
encoder_num_param = sum([p.numel() for p in encoder.parameters()])
|
||||||
|
decoder_num_param = sum([p.numel() for p in decoder.parameters()])
|
||||||
|
joiner_num_param = sum([p.numel() for p in joiner.parameters()])
|
||||||
|
total_num_param = encoder_num_param + decoder_num_param + joiner_num_param
|
||||||
|
logging.info(f"encoder parameters: {encoder_num_param}")
|
||||||
|
logging.info(f"decoder parameters: {decoder_num_param}")
|
||||||
|
logging.info(f"joiner parameters: {joiner_num_param}")
|
||||||
|
logging.info(f"total parameters: {total_num_param}")
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
suffix = f"iter-{params.iter}"
|
||||||
|
else:
|
||||||
|
suffix = f"epoch-{params.epoch}"
|
||||||
|
|
||||||
|
suffix += f"-avg-{params.avg}"
|
||||||
|
|
||||||
|
opset_version = 13
|
||||||
|
|
||||||
|
logging.info("Exporting encoder")
|
||||||
|
encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx"
|
||||||
|
export_encoder_model_onnx(
|
||||||
|
encoder,
|
||||||
|
encoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported encoder to {encoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting decoder")
|
||||||
|
decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx"
|
||||||
|
export_decoder_model_onnx(
|
||||||
|
decoder,
|
||||||
|
decoder_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported decoder to {decoder_filename}")
|
||||||
|
|
||||||
|
logging.info("Exporting joiner")
|
||||||
|
joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx"
|
||||||
|
export_joiner_model_onnx(
|
||||||
|
joiner,
|
||||||
|
joiner_filename,
|
||||||
|
opset_version=opset_version,
|
||||||
|
)
|
||||||
|
logging.info(f"Exported joiner to {joiner_filename}")
|
||||||
|
|
||||||
|
# Generate int8 quantization models
|
||||||
|
# See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection
|
||||||
|
|
||||||
|
logging.info("Generate int8 quantization models")
|
||||||
|
|
||||||
|
encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=encoder_filename,
|
||||||
|
model_output=encoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=decoder_filename,
|
||||||
|
model_output=decoder_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx"
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=joiner_filename,
|
||||||
|
model_output=joiner_filename_int8,
|
||||||
|
op_types_to_quantize=["MatMul"],
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           encoder:
             It is the transcription network in the paper. Its accepts
             two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
-            It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
             `logit_lens` of shape (N,).
           decoder:
             It is the prediction network in the paper. Its input shape

544  egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py  (new executable file)

@@ -0,0 +1,544 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
This script loads ONNX models exported by ./export-onnx-streaming.py
and uses them to decode waves.

We use the pre-trained model from
https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
as an example to show how to use this file.

1. Download the pre-trained model

cd egs/librispeech/ASR

repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
popd

2. Export the model to ONNX

./zipformer/export-onnx-streaming.py \
  --bpe-model $repo/data/lang_bpe_500/bpe.model \
  --use-averaged-model 0 \
  --epoch 99 \
  --avg 1 \
  --exp-dir $repo/exp \
  --num-encoder-layers "2,2,3,4,3,2" \
  --downsampling-factor "1,2,4,8,4,2" \
  --feedforward-dim "512,768,1024,1536,1024,768" \
  --num-heads "4,4,4,8,4,4" \
  --encoder-dim "192,256,384,512,384,256" \
  --query-head-dim 32 \
  --value-head-dim 12 \
  --pos-head-dim 4 \
  --pos-dim 48 \
  --encoder-unmasked-dim "192,192,256,256,256,192" \
  --cnn-module-kernel "31,31,15,15,15,31" \
  --decoder-dim 512 \
  --joiner-dim 512 \
  --causal True \
  --chunk-size 16 \
  --left-context-frames 64

It will generate the following 3 files inside $repo/exp:

  - encoder-epoch-99-avg-1.onnx
  - decoder-epoch-99-avg-1.onnx
  - joiner-epoch-99-avg-1.onnx

3. Run this file with the exported ONNX models

./zipformer/onnx_pretrained-streaming.py \
  --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \
  --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \
  --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \
  --tokens $repo/data/lang_bpe_500/tokens.txt \
  $repo/test_wavs/1089-134686-0001.wav

Note: Even though this script only supports decoding a single file,
the exported ONNX models do support batch processing.
"""

import argparse
import logging
from typing import Dict, List, Optional, Tuple

import k2
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature


def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the encoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the decoder onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-model-filename",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the joiner onnx model. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tokens",
|
||||||
|
type=str,
|
||||||
|
help="""Path to tokens.txt.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"sound_file",
|
||||||
|
type=str,
|
||||||
|
help="The input sound file to transcribe. "
|
||||||
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
|
"For example, wav and flac are supported. "
|
||||||
|
"The sample rate has to be 16kHz.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
    def __init__(
        self,
        encoder_model_filename: str,
        decoder_model_filename: str,
        joiner_model_filename: str,
    ):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1
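        # Note: single-threaded ORT sessions are used here, presumably to keep
        # per-chunk latency predictable; both values can be increased if
        # throughput matters more than latency.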

        self.session_opts = session_opts

        self.init_encoder(encoder_model_filename)
        self.init_decoder(decoder_model_filename)
        self.init_joiner(joiner_model_filename)

    def init_encoder(self, encoder_model_filename: str):
        self.encoder = ort.InferenceSession(
            encoder_model_filename,
            sess_options=self.session_opts,
        )
        self.init_encoder_states()

    def init_encoder_states(self, batch_size: int = 1):
        encoder_meta = self.encoder.get_modelmeta().custom_metadata_map
        logging.info(f"encoder_meta={encoder_meta}")

        model_type = encoder_meta["model_type"]
        assert model_type == "zipformer2", model_type

        decode_chunk_len = int(encoder_meta["decode_chunk_len"])
        T = int(encoder_meta["T"])

        num_encoder_layers = encoder_meta["num_encoder_layers"]
        encoder_dims = encoder_meta["encoder_dims"]
        cnn_module_kernels = encoder_meta["cnn_module_kernels"]
        left_context_len = encoder_meta["left_context_len"]
        query_head_dims = encoder_meta["query_head_dims"]
        value_head_dims = encoder_meta["value_head_dims"]
        num_heads = encoder_meta["num_heads"]

        def to_int_list(s):
            return list(map(int, s.split(",")))

        num_encoder_layers = to_int_list(num_encoder_layers)
        encoder_dims = to_int_list(encoder_dims)
        cnn_module_kernels = to_int_list(cnn_module_kernels)
        left_context_len = to_int_list(left_context_len)
        query_head_dims = to_int_list(query_head_dims)
        value_head_dims = to_int_list(value_head_dims)
        num_heads = to_int_list(num_heads)

        logging.info(f"decode_chunk_len: {decode_chunk_len}")
        logging.info(f"T: {T}")
        logging.info(f"num_encoder_layers: {num_encoder_layers}")
        logging.info(f"encoder_dims: {encoder_dims}")
        logging.info(f"cnn_module_kernels: {cnn_module_kernels}")
        logging.info(f"left_context_len: {left_context_len}")
        logging.info(f"query_head_dims: {query_head_dims}")
        logging.info(f"value_head_dims: {value_head_dims}")
        logging.info(f"num_heads: {num_heads}")

        num_encoders = len(num_encoder_layers)

        self.states = []
        for i in range(num_encoders):
            num_layers = num_encoder_layers[i]
            key_dim = query_head_dims[i] * num_heads[i]
            embed_dim = encoder_dims[i]
            nonlin_attn_head_dim = 3 * embed_dim // 4
            value_dim = value_head_dims[i] * num_heads[i]
            conv_left_pad = cnn_module_kernels[i] // 2

            for layer in range(num_layers):
                cached_key = torch.zeros(
                    left_context_len[i], batch_size, key_dim
                ).numpy()
                cached_nonlin_attn = torch.zeros(
                    1, batch_size, left_context_len[i], nonlin_attn_head_dim
                ).numpy()
                cached_val1 = torch.zeros(
                    left_context_len[i], batch_size, value_dim
                ).numpy()
                cached_val2 = torch.zeros(
                    left_context_len[i], batch_size, value_dim
                ).numpy()
                cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
                cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy()
                self.states += [
                    cached_key,
                    cached_nonlin_attn,
                    cached_val1,
                    cached_val2,
                    cached_conv1,
                    cached_conv2,
                ]
        embed_states = torch.zeros(batch_size, 128, 3, 19).numpy()
        self.states.append(embed_states)
        processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy()
        self.states.append(processed_lens)
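        # self.states is now a flat list: for every layer of every encoder
        # stack there are 6 cached tensors (cached_key, cached_nonlin_attn,
        # cached_val1, cached_val2, cached_conv1, cached_conv2), followed by
        # embed_states and processed_lens. _build_encoder_input_output()
        # below relies on exactly this ordering.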

        self.num_encoders = num_encoders

        self.segment = T
        self.offset = decode_chunk_len
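        # `segment` is the number of feature frames fed to the encoder per
        # call, while `offset` is how far the feature pointer advances after
        # each call, so consecutive chunks overlap by T - decode_chunk_len
        # frames.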

    def init_decoder(self, decoder_model_filename: str):
        self.decoder = ort.InferenceSession(
            decoder_model_filename,
            sess_options=self.session_opts,
        )

        decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
        self.context_size = int(decoder_meta["context_size"])
        self.vocab_size = int(decoder_meta["vocab_size"])

        logging.info(f"context_size: {self.context_size}")
        logging.info(f"vocab_size: {self.vocab_size}")

    def init_joiner(self, joiner_model_filename: str):
        self.joiner = ort.InferenceSession(
            joiner_model_filename,
            sess_options=self.session_opts,
        )

        joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
        self.joiner_dim = int(joiner_meta["joiner_dim"])

        logging.info(f"joiner_dim: {self.joiner_dim}")

    def _build_encoder_input_output(
        self,
        x: torch.Tensor,
    ) -> Tuple[Dict[str, np.ndarray], List[str]]:
        encoder_input = {"x": x.numpy()}
        encoder_output = ["encoder_out"]

        def build_inputs_outputs(tensors, i):
            assert len(tensors) == 6, len(tensors)

            # (downsample_left, batch_size, key_dim)
            name = f"cached_key_{i}"
            encoder_input[name] = tensors[0]
            encoder_output.append(f"new_{name}")

            # (1, batch_size, downsample_left, nonlin_attn_head_dim)
            name = f"cached_nonlin_attn_{i}"
            encoder_input[name] = tensors[1]
            encoder_output.append(f"new_{name}")

            # (downsample_left, batch_size, value_dim)
            name = f"cached_val1_{i}"
            encoder_input[name] = tensors[2]
            encoder_output.append(f"new_{name}")

            # (downsample_left, batch_size, value_dim)
            name = f"cached_val2_{i}"
            encoder_input[name] = tensors[3]
            encoder_output.append(f"new_{name}")

            # (batch_size, embed_dim, conv_left_pad)
            name = f"cached_conv1_{i}"
            encoder_input[name] = tensors[4]
            encoder_output.append(f"new_{name}")

            # (batch_size, embed_dim, conv_left_pad)
            name = f"cached_conv2_{i}"
            encoder_input[name] = tensors[5]
            encoder_output.append(f"new_{name}")

        for i in range(len(self.states[:-2]) // 6):
            build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i)
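        # Note: i above indexes layers across all encoder stacks, in the same
        # order in which init_encoder_states() appended them.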

        # (batch_size, channels, left_pad, freq)
        name = "embed_states"
        embed_states = self.states[-2]
        encoder_input[name] = embed_states
        encoder_output.append(f"new_{name}")

        # (batch_size,)
        name = "processed_lens"
        processed_lens = self.states[-1]
        encoder_input[name] = processed_lens
        encoder_output.append(f"new_{name}")

        return encoder_input, encoder_output

    def _update_states(self, states: List[np.ndarray]):
        self.states = states

    def run_encoder(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C)
        Returns:
          Return a 3-D tensor of shape (N, T', joiner_dim) where
          T' is usually equal to ((T-7)//2+1)//2
        """
        encoder_input, encoder_output_names = self._build_encoder_input_output(x)

        out = self.encoder.run(encoder_output_names, encoder_input)
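        # out[0] is encoder_out; the remaining outputs are the updated cache
        # tensors, in the same order as the inputs, so they can be fed back
        # unchanged on the next chunk.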

        self._update_states(out[1:])

        return torch.from_numpy(out[0])

    def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
          decoder_input:
            A 2-D tensor of shape (N, context_size)
        Returns:
          Return a 2-D tensor of shape (N, joiner_dim)
        """
        out = self.decoder.run(
            [self.decoder.get_outputs()[0].name],
            {self.decoder.get_inputs()[0].name: decoder_input.numpy()},
        )[0]

        return torch.from_numpy(out)

    def run_joiner(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            A 2-D tensor of shape (N, joiner_dim)
          decoder_out:
            A 2-D tensor of shape (N, joiner_dim)
        Returns:
          Return a 2-D tensor of shape (N, vocab_size)
        """
        out = self.joiner.run(
            [self.joiner.get_outputs()[0].name],
            {
                self.joiner.get_inputs()[0].name: encoder_out.numpy(),
                self.joiner.get_inputs()[1].name: decoder_out.numpy(),
            },
        )[0]

        return torch.from_numpy(out)


def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.
    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert (
            sample_rate == expected_sample_rate
        ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
        # We use only the first channel
        ans.append(wave[0].contiguous())
    return ans


def create_streaming_feature_extractor() -> OnlineFeature:
    """Create a CPU streaming feature extractor.

    At present, we assume it returns a fbank feature extractor with
    fixed options. In the future, we will support passing in the options
    from outside.

    Returns:
      Return a CPU streaming feature extractor.
    """
    opts = FbankOptions()
    opts.device = "cpu"
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = 16000
    opts.mel_opts.num_bins = 80
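    # 80-dimensional fbank features at 16 kHz match what the released
    # streaming zipformer checkpoints were trained on; adjust these only if
    # your model used different features.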
    return OnlineFbank(opts)


def greedy_search(
    model: OnnxModel,
    encoder_out: torch.Tensor,
    context_size: int,
    decoder_out: Optional[torch.Tensor] = None,
    hyp: Optional[List[int]] = None,
) -> Tuple[List[int], torch.Tensor]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
    Args:
      model:
        The transducer model.
      encoder_out:
        A 3-D tensor of shape (1, T, joiner_dim)
      context_size:
        The context size of the decoder model.
      decoder_out:
        Optional. Decoder output of the previous chunk.
      hyp:
        Decoding results for previous chunks.
    Returns:
      Return a tuple containing the decoded results so far and the latest
      decoder output.
    """

    blank_id = 0

    if decoder_out is None:
        assert hyp is None, hyp
        hyp = [blank_id] * context_size
        decoder_input = torch.tensor([hyp], dtype=torch.int64)
        decoder_out = model.run_decoder(decoder_input)
    else:
        assert hyp is not None, hyp

    encoder_out = encoder_out.squeeze(0)
    T = encoder_out.size(0)
    for t in range(T):
        cur_encoder_out = encoder_out[t : t + 1]
        joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0)
        y = joiner_out.argmax(dim=0).item()
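        # Emit at most one non-blank symbol per frame; the decoder output is
        # recomputed only when a non-blank symbol is appended to hyp.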
        if y != blank_id:
            hyp.append(y)
            decoder_input = hyp[-context_size:]
            decoder_input = torch.tensor([decoder_input], dtype=torch.int64)
            decoder_out = model.run_decoder(decoder_input)

    return hyp, decoder_out


@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()
    logging.info(vars(args))

    model = OnnxModel(
        encoder_model_filename=args.encoder_model_filename,
        decoder_model_filename=args.decoder_model_filename,
        joiner_model_filename=args.joiner_model_filename,
    )

    sample_rate = 16000

    logging.info("Constructing Fbank computer")
    online_fbank = create_streaming_feature_extractor()

    logging.info(f"Reading sound files: {args.sound_file}")
    waves = read_sound_files(
        filenames=[args.sound_file],
        expected_sample_rate=sample_rate,
    )[0]

    tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32)
    wave_samples = torch.cat([waves, tail_padding])
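    # The extra 0.3 seconds of silence helps flush the last feature frames of
    # the utterance through the streaming pipeline so the tail is decoded too.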

    num_processed_frames = 0
    segment = model.segment
    offset = model.offset

    context_size = model.context_size
    hyp = None
    decoder_out = None

    chunk = int(1 * sample_rate)  # 1 second
    start = 0
    while start < wave_samples.numel():
        end = min(start + chunk, wave_samples.numel())
        samples = wave_samples[start:end]
        start += chunk

        online_fbank.accept_waveform(
            sampling_rate=sample_rate,
            waveform=samples,
        )
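        # Run the encoder whenever at least `segment` feature frames are
        # available; each call consumes `segment` frames and then advances
        # the read pointer by `offset` frames.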

        while online_fbank.num_frames_ready - num_processed_frames >= segment:
            frames = []
            for i in range(segment):
                frames.append(online_fbank.get_frame(num_processed_frames + i))
            num_processed_frames += offset
            frames = torch.cat(frames, dim=0)
            frames = frames.unsqueeze(0)
            encoder_out = model.run_encoder(frames)
            hyp, decoder_out = greedy_search(
                model,
                encoder_out,
                context_size,
                decoder_out,
                hyp,
            )

    symbol_table = k2.SymbolTable.from_file(args.tokens)

    text = ""
    for i in hyp[context_size:]:
        text += symbol_table[i]
    text = text.replace("▁", " ").strip()
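    # BPE pieces mark word boundaries with "▁"; replacing it with a space
    # turns the token sequence into the final transcript.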

    logging.info(args.sound_file)
    logging.info(text)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

egs/librispeech/ASR/zipformer/onnx_pretrained.py (Symbolic link)
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/onnx_pretrained.py
@@ -26,6 +26,18 @@ import torch.nn as nn
 from torch import Tensor
 
 
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+    if not torch.jit.is_tracing():
+        return torch.logaddexp(x, y)
+    else:
+        return (x.exp() + y.exp()).log()
+
+
 class PiecewiseLinear(object):
     """
     Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with

@@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
 
     def __float__(self):
         batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
+        if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return float(self.default)
         else:
             ans = self.schedule(self.batch_count)

@@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
 
 
 def softmax(x: Tensor, dim: int):
-    if not x.requires_grad or torch.jit.is_scripting():
+    if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
         return x.softmax(dim=dim)
 
     return SoftmaxFunction.apply(x, dim)

@@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
         self.alpha = alpha
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return x
         return scale_grad(x, self.alpha)
 

@@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
 
 
 def _no_op(x: Tensor) -> Tensor:
-    if (torch.jit.is_scripting()):
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
         return x
     else:
         # a no-op function that will have a node in the autograd graph,

@@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
 

@@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-L activation.
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+            return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
             return k2.swoosh_l_forward(x)
         else:

@@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-R activation.
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+            return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         if not x.requires_grad:
             return k2.swoosh_r_forward(x)
         else:
@@ -27,6 +27,7 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn
 from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from zipformer import CompactRelPositionalEncoding
 
 
 # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa

@@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
     model: nn.Module,
     inplace: bool = False,
     is_pnnx: bool = False,
+    is_onnx: bool = False,
 ):
     """
     Args:

@@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
         If False, the input model is copied and we modify the copied version.
       is_pnnx:
        True if we are going to export the model for PNNX.
+      is_onnx:
+        True if we are going to export the model for ONNX.
     Return:
       Return a model without scaled layers.
     """

@@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
     for name, m in model.named_modules():
         if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
             d[name] = nn.Identity()
+        elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
+            # We want to recreate the positional encoding vector when
+            # the input changes, so we have to use torch.jit.script()
+            # to replace torch.jit.trace()
+            d[name] = torch.jit.script(m)
 
     for k, v in d.items():
         if "." in k:
@@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
         )
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.forward_internal(x)
         layerdrop_rate = float(self.layerdrop_rate)
 

@@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.dropout(x)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             x_lens = (x_lens - 7) // 2
         else:
             with warnings.catch_warnings():
@@ -133,6 +133,7 @@ class Zipformer2(EncoderInterface):
         self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
         self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim)  # tuple
         num_encoder_layers = _to_tuple(num_encoder_layers)
+        self.num_encoder_layers = num_encoder_layers
         self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
         self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
         pos_head_dim = _to_tuple(pos_head_dim)

@@ -258,7 +259,7 @@ class Zipformer2(EncoderInterface):
         if not self.causal:
             return -1, -1
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             assert len(self.chunk_size) == 1, self.chunk_size
             chunk_size = self.chunk_size[0]
         else:

@@ -267,7 +268,7 @@ class Zipformer2(EncoderInterface):
         if chunk_size == -1:
             left_context_chunks = -1
         else:
-            if torch.jit.is_scripting():
+            if torch.jit.is_scripting() or torch.jit.is_tracing():
                 assert len(self.left_context_frames) == 1, self.left_context_frames
                 left_context_frames = self.left_context_frames[0]
             else:

@@ -301,14 +302,14 @@ class Zipformer2(EncoderInterface):
            of frames in `embeddings` before padding.
         """
         outputs = []
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             feature_masks = [1.0] * len(self.encoder_dim)
         else:
             feature_masks = self.get_feature_masks(x)
 
         chunk_size, left_context_chunks = self.get_chunk_info()
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # Not support exporting a model for simulating streaming decoding
             attn_mask = None
         else:

@@ -334,7 +335,7 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
         else:
             with warnings.catch_warnings():
@@ -372,7 +373,7 @@ class Zipformer2(EncoderInterface):
         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
         # c is chunk index for each frame, shape (seq_len,)
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             c = t // chunk_size
         else:
             with warnings.catch_warnings():

@@ -650,7 +651,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )
 
     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
-        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return None
         batch_size = x.shape[1]
         mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)

@@ -695,7 +696,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src_orig = src
 
         # dropout rate for non-feedforward submodules
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             attention_skip_rate = 0.0
         else:
             attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0

@@ -713,7 +714,7 @@ class Zipformer2EncoderLayer(nn.Module):
         self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         selected_attn_weights = attn_weights[0:1]
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif not self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to

@@ -732,7 +733,7 @@ class Zipformer2EncoderLayer(nn.Module):
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0

@@ -740,7 +741,7 @@ class Zipformer2EncoderLayer(nn.Module):
                                                src_key_padding_mask=src_key_padding_mask),
                                            conv_skip_rate)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff2_skip_rate = 0.0
         else:
             ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0

@@ -754,7 +755,7 @@ class Zipformer2EncoderLayer(nn.Module):
 
         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0

@@ -762,7 +763,7 @@ class Zipformer2EncoderLayer(nn.Module):
                                                src_key_padding_mask=src_key_padding_mask),
                                            conv_skip_rate)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff3_skip_rate = 0.0
         else:
             ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@@ -968,7 +969,7 @@ class Zipformer2Encoder(nn.Module):
         pos_emb = self.encoder_pos(src)
         output = src
 
-        if not torch.jit.is_scripting():
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             output = output * feature_mask
 
         for i, mod in enumerate(self.layers):

@@ -980,7 +981,7 @@ class Zipformer2Encoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
             )
 
-            if not torch.jit.is_scripting():
+            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                 output = output * feature_mask
 
         return output

@@ -1073,7 +1074,7 @@ class BypassModule(nn.Module):
         # or (batch_size, num_channels,). This is actually the
         # scale on the non-residual term, so 0 correponds to bypassing
         # this module.
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.bypass_scale
         else:
             ans = limit_param_value(self.bypass_scale,

@@ -1229,7 +1230,6 @@ class SimpleDownsample(torch.nn.Module):
         d_seq_len = (seq_len + ds - 1) // ds
 
         # Pad to an exact multiple of self.downsample
-        if seq_len != d_seq_len * ds:
         # right-pad src, repeating the last element.
         pad = d_seq_len * ds - seq_len
         src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])

@@ -1322,10 +1322,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # self.pe contains both positive and negative parts
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(0) >= T * 2 - 1:
-            # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                x.device
-            ):
             self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
 
@@ -1524,7 +1520,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)
 
         use_pos_scores = False
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # We can't put random.random() in the same line
             use_pos_scores = True
         elif not self.training or random.random() >= float(self.pos_emb_skip_rate):

@@ -1542,6 +1538,16 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
             # the following .as_strided() expression converts the last axis of pos_scores from relative
             # to absolute position. I don't know whether I might have got the time-offsets backwards or
             # not, but let this code define which way round it is supposed to be.
+            if torch.jit.is_tracing():
+                (num_heads, batch_size, time1, n) = pos_scores.shape
+                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+                cols = torch.arange(seq_len)
+                rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+                indexes = rows + cols
+                pos_scores = pos_scores.reshape(-1, n)
+                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+                pos_scores = pos_scores.reshape(num_heads, batch_size, time1, seq_len)
+            else:
                 pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
                                                    (pos_scores.stride(0),
                                                     pos_scores.stride(1),

@@ -1551,7 +1557,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = attn_scores + pos_scores
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be

@@ -1594,7 +1600,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif random.random() < 0.001 and not self.training:
             self._print_attn_entropy(attn_weights)

@@ -1672,9 +1678,20 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
         # [where seq_len2 represents relative position.]
         pos_scores = torch.matmul(p, pos_emb)
+
+        if torch.jit.is_tracing():
+            (num_heads, batch_size, time1, n) = pos_scores.shape
+            rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+            cols = torch.arange(k_len)
+            rows = rows.repeat(batch_size * num_heads).unsqueeze(-1)
+            indexes = rows + cols
+            pos_scores = pos_scores.reshape(-1, n)
+            pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+            pos_scores = pos_scores.reshape(num_heads, batch_size, time1, k_len)
         # the following .as_strided() expression converts the last axis of pos_scores from relative
         # to absolute position. I don't know whether I might have got the time-offsets backwards or
         # not, but let this code define which way round it is supposed to be.
+        else:
             pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, k_len),
                                                (pos_scores.stride(0),
                                                 pos_scores.stride(1),

@@ -2136,7 +2153,7 @@ class ConvolutionModule(nn.Module):
         if src_key_padding_mask is not None:
             x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
 
-        if not torch.jit.is_scripting() and chunk_size >= 0:
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
             # Not support exporting a model for simulated streaming decoding
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
             x = self.depthwise_conv(x, chunk_size=chunk_size)