Fix exporting non-streaming zipformer to ONNX via torch.jit.trace()

Fangjun Kuang 2023-06-02 14:21:28 +08:00
parent f43c44236f
commit dd91c89f28
6 changed files with 66 additions and 36 deletions

View File

@ -13,7 +13,7 @@ as an example to show how to use this file.
cd egs/librispeech/ASR
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
@ -27,12 +27,29 @@ popd
2. Export the model to ONNX
./zipformer/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp/
--exp-dir $repo/exp \
\
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
--num-heads "4,4,4,8,4,4" \
--encoder-dim "192,256,384,512,384,256" \
--query-head-dim 32 \
--value-head-dim 12 \
--pos-head-dim 4 \
--pos-dim 48 \
--encoder-unmasked-dim "192,192,256,256,256,192" \
--cnn-module-kernel "31,31,15,15,15,31" \
--decoder-dim 512 \
--joiner-dim 512 \
--causal False \
--chunk-size "16,32,64,-1" \
--left-context-frames "64,128,256,-1"
It will generate the following 3 files inside $repo/exp:
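As a quick way to verify the exported encoder afterwards, it can be loaded with onnxruntime. A minimal sketch follows; the file name encoder-epoch-99-avg-1.onnx and the input names x / x_lens are assumptions about the export script's conventions rather than something shown in this hunk:

import numpy as np
import onnxruntime as ort

# File and input names below are assumptions; check export-onnx.py for the
# actual values used by the exported model.
session = ort.InferenceSession(
    "encoder-epoch-99-avg-1.onnx",
    providers=["CPUExecutionProvider"],
)
x = np.zeros((1, 100, 80), dtype=np.float32)   # (N, T, num_features)
x_lens = np.array([100], dtype=np.int64)       # number of valid frames per utterance
outputs = session.run(None, {"x": x, "x_lens": x_lens})
print([o.shape for o in outputs])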
@ -299,7 +316,7 @@ def export_encoder_model_onnx(
"model_type": "zipformer",
"version": "1",
"model_author": "k2-fsa",
"comment": "stateless7",
"comment": "zipformer",
}
logging.info(f"meta_data: {meta_data}")

View File

@ -49,7 +49,7 @@ class Transducer(nn.Module):
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/onnx_pretrained.py

View File

@ -26,6 +26,18 @@ import torch.nn as nn
from torch import Tensor
# RuntimeError: Exporting the operator logaddexp to ONNX opset version
# 14 is not supported. Please feel free to request support or submit
# a pull request on PyTorch GitHub.
#
# The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace()
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing():
return torch.logaddexp(x, y)
else:
return (x.exp() + y.exp()).log()
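A short, self-contained sketch (not part of the diff) illustrating the workaround: the fallback expression is mathematically identical to torch.logaddexp, only less overflow-safe for very large inputs, and because torch.onnx.export() runs the model under torch.jit.trace(), torch.jit.is_tracing() is True during export and the trace-friendly branch is the one recorded in the ONNX graph. The Toy module and the toy.onnx file name are made up for illustration:

import torch
import torch.nn as nn

x = torch.tensor([0.5, -1.0, 3.0])
y = torch.tensor([0.25, 2.0, -0.5])
# The two formulations agree numerically.
assert torch.allclose(torch.logaddexp(x, y), (x.exp() + y.exp()).log())

class Toy(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        zero = torch.zeros_like(x)
        if torch.jit.is_tracing():
            # Trace/ONNX-friendly path; avoids the unsupported logaddexp op.
            return (zero.exp() + x.exp()).log()
        return torch.logaddexp(zero, x)

# torch.onnx.export() traces the module, so the is_tracing() branch is taken.
torch.onnx.export(Toy(), (x,), "toy.onnx", opset_version=13)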
class PiecewiseLinear(object):
"""
Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
def __float__(self):
batch_count = self.batch_count
if batch_count is None or not self.training or torch.jit.is_scripting():
if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return float(self.default)
else:
ans = self.schedule(self.batch_count)
@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
def softmax(x: Tensor, dim: int):
if not x.requires_grad or torch.jit.is_scripting():
if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
return x.softmax(dim=dim)
return SoftmaxFunction.apply(x, dim)
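The same one-line pattern is applied throughout the remaining hunks: torch.jit.is_scripting() stays False while a model is being traced, so checks that only tested is_scripting() would let training-time branches (Python-level randomness, custom autograd functions) leak into the trace. A minimal sketch of the behavior; the function below is made up and only stands in for the dropout/skip-rate logic in these files:

import torch

def pick_branch(x: torch.Tensor) -> torch.Tensor:
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # Export-friendly path: no Python-level randomness or custom autograd.
        return x.softmax(dim=-1)
    # Stand-in for a training-time path with random masking.
    keep = torch.rand(()) > 0.5
    return x.softmax(dim=-1) * keep.to(x.dtype)

traced = torch.jit.trace(pick_branch, torch.randn(2, 4))
# Because is_tracing() is checked, the traced graph contains only the
# deterministic softmax, not the random masking.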
@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
self.alpha = alpha
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return x
return scale_grad(x, self.alpha)
@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
def _no_op(x: Tensor) -> Tensor:
if (torch.jit.is_scripting()):
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x
else:
# a no-op function that will have a node in the autograd graph,
@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
that we approximate closely with x * sigmoid(x-1).
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
return x * torch.sigmoid(x - 1.0)
return DoubleSwishFunction.apply(x)
@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-L activation.
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not x.requires_grad:
return k2.swoosh_l_forward(x)
else:
@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
"""Return Swoosh-R activation.
"""
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
if not x.requires_grad:
return k2.swoosh_r_forward(x)
else:

View File

@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():

View File

@ -258,7 +258,7 @@ class Zipformer2(EncoderInterface):
if not self.causal:
return -1, -1
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.chunk_size) == 1, self.chunk_size
chunk_size = self.chunk_size[0]
else:
@ -267,7 +267,7 @@ class Zipformer2(EncoderInterface):
if chunk_size == -1:
left_context_chunks = -1
else:
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert len(self.left_context_frames) == 1, self.left_context_frames
left_context_frames = self.left_context_frames[0]
else:
@ -301,14 +301,14 @@ class Zipformer2(EncoderInterface):
of frames in `embeddings` before padding.
"""
outputs = []
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
feature_masks = [1.0] * len(self.encoder_dim)
else:
feature_masks = self.get_feature_masks(x)
chunk_size, left_context_chunks = self.get_chunk_info()
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
# Exporting a model for simulating streaming decoding is not supported
attn_mask = None
else:
@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface):
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
lengths = (x_lens + 1) // 2
else:
with warnings.catch_warnings():
@ -372,7 +372,7 @@ class Zipformer2(EncoderInterface):
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
# c is chunk index for each frame, shape (seq_len,)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
c = t // chunk_size
else:
with warnings.catch_warnings():
@ -650,7 +650,7 @@ class Zipformer2EncoderLayer(nn.Module):
)
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
return None
batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@ -695,7 +695,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_orig = src
# dropout rate for non-feedforward submodules
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
attention_skip_rate = 0.0
else:
attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@ -713,7 +713,7 @@ class Zipformer2EncoderLayer(nn.Module):
self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif not self.training and random.random() < float(self.const_attention_rate):
# Make attention weights constant. The intention is to
@ -732,7 +732,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -740,7 +740,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
conv_skip_rate)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
ff2_skip_rate = 0.0
else:
ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@ -754,7 +754,7 @@ class Zipformer2EncoderLayer(nn.Module):
src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@ -762,7 +762,7 @@ class Zipformer2EncoderLayer(nn.Module):
src_key_padding_mask=src_key_padding_mask),
conv_skip_rate)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
ff3_skip_rate = 0.0
else:
ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@ -968,7 +968,7 @@ class Zipformer2Encoder(nn.Module):
pos_emb = self.encoder_pos(src)
output = src
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask
for i, mod in enumerate(self.layers):
@ -980,7 +980,7 @@ class Zipformer2Encoder(nn.Module):
src_key_padding_mask=src_key_padding_mask,
)
if not torch.jit.is_scripting():
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
output = output * feature_mask
return output
@ -1073,7 +1073,7 @@ class BypassModule(nn.Module):
# or (batch_size, num_channels,). This is actually the
# scale on the non-residual term, so 0 corresponds to bypassing
# this module.
if torch.jit.is_scripting() or not self.training:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.bypass_scale
else:
ans = limit_param_value(self.bypass_scale,
@ -1524,7 +1524,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k)
use_pos_scores = False
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
# We can't put random.random() in the same line
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif self.training and random.random() < 0.1:
# This is a harder way of limiting the attention scores to not be
@ -1604,7 +1604,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# half-precision output for backprop purposes.
attn_weights = softmax(attn_scores, dim=-1)
if torch.jit.is_scripting():
if torch.jit.is_scripting() or torch.jit.is_tracing():
pass
elif random.random() < 0.001 and not self.training:
self._print_attn_entropy(attn_weights)
@ -2146,7 +2146,7 @@ class ConvolutionModule(nn.Module):
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
if not torch.jit.is_scripting() and chunk_size >= 0:
if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
# Exporting a model for simulated streaming decoding is not supported
assert self.causal, "Must initialize model with causal=True if you use chunk_size"
x = self.depthwise_conv(x, chunk_size=chunk_size)