Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00

Merge pull request #1 from csukuangfj/export_zipformer2_onnx

Fix exporting non-streaming zipformer to ONNX via torch.jit.trace()

This change is contained in commit f59e06c556.
@@ -13,7 +13,7 @@ as an example to show how to use this file.
 cd egs/librispeech/ASR

-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
@@ -27,12 +27,29 @@ popd

 2. Export the model to ONNX

-./zipformer/export-onnx.py \ 148 ↵
+./zipformer/export-onnx.py \
   --bpe-model $repo/data/lang_bpe_500/bpe.model \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
-  --exp-dir $repo/exp/
+  --exp-dir $repo/exp \
+  --num-encoder-layers "2,2,3,4,3,2" \
+  --downsampling-factor "1,2,4,8,4,2" \
+  --feedforward-dim "512,768,1024,1536,1024,768" \
+  --num-heads "4,4,4,8,4,4" \
+  --encoder-dim "192,256,384,512,384,256" \
+  --query-head-dim 32 \
+  --value-head-dim 12 \
+  --pos-head-dim 4 \
+  --pos-dim 48 \
+  --encoder-unmasked-dim "192,192,256,256,256,192" \
+  --cnn-module-kernel "31,31,15,15,15,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512 \
+  --causal False \
+  --chunk-size "16,32,64,-1" \
+  --left-context-frames "64,128,256,-1"

 It will generate the following 3 files inside $repo/exp:
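As a quick sanity check (not part of this commit), the exported encoder can be loaded with onnxruntime and run on random fbank features. This is a minimal sketch; the output file name (encoder-epoch-99-avg-1.onnx is assumed here) and the input dtypes should be checked against what export-onnx.py actually writes.

import numpy as np
import onnxruntime as ort

# Assumed file name produced by the export command above.
session = ort.InferenceSession(
    "exp/encoder-epoch-99-avg-1.onnx",
    providers=["CPUExecutionProvider"],
)

# Fake batch: one utterance, 100 frames of 80-dim fbank features.
inputs = session.get_inputs()
x = np.random.randn(1, 100, 80).astype(np.float32)
x_lens = np.array([100], dtype=np.int64)

# None means "return all outputs"; we look their shapes up afterwards.
outputs = session.run(None, {inputs[0].name: x, inputs[1].name: x_lens})
print([o.shape for o in outputs])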
@@ -115,7 +132,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless5/exp",
+        default="zipformer/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
@@ -299,7 +316,7 @@ def export_encoder_model_onnx(
         "model_type": "zipformer",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "stateless7",
+        "comment": "zipformer",
     }
     logging.info(f"meta_data: {meta_data}")
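The meta_data dict shown above ends up stored inside the exported ONNX file. A sketch of how such key/value pairs can be written and read back with the standard onnx package (the helper in export-onnx.py may be named or structured differently):

import onnx

def add_meta_data(filename: str, meta_data: dict) -> None:
    # Store each key/value pair in the model's metadata_props.
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)
    onnx.save(model, filename)

def read_meta_data(filename: str) -> dict:
    model = onnx.load(filename)
    return {p.key: p.value for p in model.metadata_props}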
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
       encoder:
         It is the transcription network in the paper. Its accepts
         two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
-        It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+        It returns two tensors: `logits` of shape (N, T, encoder_dim) and
         `logit_lens` of shape (N,).
       decoder:
         It is the prediction network in the paper. Its input shape
egs/librispeech/ASR/zipformer/onnx_pretrained.py (new symbolic link)
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/onnx_pretrained.py
@@ -26,6 +26,18 @@ import torch.nn as nn
 from torch import Tensor


+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+    if not torch.jit.is_tracing():
+        return torch.logaddexp(x, y)
+    else:
+        return (x.exp() + y.exp()).log()
+
+
 class PiecewiseLinear(object):
     """
     Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
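The reason for this helper: exporting a traced model with torch.onnx.export() fails on the logaddexp operator at the opset versions used here, so while tracing the function falls back to the algebraically equivalent exp/log form. A minimal reproduction sketch, not taken from the repo; the module and file names are illustrative only:

import torch
from torch import Tensor

def logaddexp(x: Tensor, y: Tensor) -> Tensor:
    # Same workaround as in scaling.py: avoid logaddexp while tracing.
    if not torch.jit.is_tracing():
        return torch.logaddexp(x, y)
    else:
        return (x.exp() + y.exp()).log()

class TinySwooshL(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035

m = TinySwooshL()
x = torch.randn(2, 3)
traced = torch.jit.trace(m, x)
# The export succeeds because the traced graph contains exp/add/log
# nodes instead of the unsupported logaddexp operator.
torch.onnx.export(traced, x, "tiny_swoosh_l.onnx", opset_version=14)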
@@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):

     def __float__(self):
         batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
+        if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return float(self.default)
         else:
             ans = self.schedule(self.batch_count)
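Most of the remaining hunks apply the same one-line fix: every branch that should only run in eager-mode training gets an additional torch.jit.is_tracing() guard, because torch.jit.trace() simply executes the Python code and records whichever branch it happens to take. A toy illustration of the pattern; the module name is made up:

import torch
from torch import Tensor

class Probe(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Deterministic branch: this is what gets recorded into the
            # traced graph and hence into any exported ONNX model.
            return x
        # Training-time branch with Python-side randomness; it must not
        # end up baked into an exported graph.
        return x * (torch.rand(1) > 0.5)

m = Probe()
x = torch.randn(3)
print(torch.jit.is_tracing())   # False in ordinary eager execution
traced = torch.jit.trace(m, x)  # is_tracing() is True while tracing runs
print(traced(x))                # always returns x unchanged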
@@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):


 def softmax(x: Tensor, dim: int):
-    if not x.requires_grad or torch.jit.is_scripting():
+    if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
         return x.softmax(dim=dim)

     return SoftmaxFunction.apply(x, dim)
@@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
         self.alpha = alpha

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return x
         return scale_grad(x, self.alpha)
@@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,


 def _no_op(x: Tensor) -> Tensor:
-    if (torch.jit.is_scripting()):
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
         return x
     else:
         # a no-op function that will have a node in the autograd graph,
@@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
@@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-L activation.
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+            return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
             return k2.swoosh_l_forward(x)
         else:
@@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         """Return Swoosh-R activation.
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+            return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         if not x.requires_grad:
             return k2.swoosh_r_forward(x)
         else:
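A note on the substitution in SwooshL/SwooshR: (x.exp() + y.exp()).log() is mathematically the same as torch.logaddexp(x, y) but lacks its overflow protection; for moderate inputs the two agree closely, which is presumably why the simple form is acceptable on the tracing path. A quick illustrative check:

import torch

zero = torch.tensor(0.0)
t = torch.linspace(-20.0, 20.0, steps=9)

stable = torch.logaddexp(zero, t - 4.0)        # eager / scripted path
naive = (zero.exp() + (t - 4.0).exp()).log()   # path used while tracing

print(torch.allclose(stable, naive))  # True on this range
# For very large t, (t - 4.0).exp() overflows to inf in float32, while
# torch.logaddexp would still return a finite value.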
@@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
         )

     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.forward_internal(x)
         layerdrop_rate = float(self.layerdrop_rate)
@@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.dropout(x)

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             x_lens = (x_lens - 7) // 2
         else:
             with warnings.catch_warnings():
@@ -258,7 +258,7 @@ class Zipformer2(EncoderInterface):
         if not self.causal:
             return -1, -1

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             assert len(self.chunk_size) == 1, self.chunk_size
             chunk_size = self.chunk_size[0]
         else:
@@ -267,7 +267,7 @@ class Zipformer2(EncoderInterface):
         if chunk_size == -1:
             left_context_chunks = -1
         else:
-            if torch.jit.is_scripting():
+            if torch.jit.is_scripting() or torch.jit.is_tracing():
                 assert len(self.left_context_frames) == 1, self.left_context_frames
                 left_context_frames = self.left_context_frames[0]
             else:
@@ -301,14 +301,14 @@ class Zipformer2(EncoderInterface):
           of frames in `embeddings` before padding.
         """
         outputs = []
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             feature_masks = [1.0] * len(self.encoder_dim)
         else:
             feature_masks = self.get_feature_masks(x)

         chunk_size, left_context_chunks = self.get_chunk_info()

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # Not support exporting a model for simulating streaming decoding
             attn_mask = None
         else:
@@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
         else:
             with warnings.catch_warnings():
@@ -372,7 +372,7 @@ class Zipformer2(EncoderInterface):
         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
         # c is chunk index for each frame, shape (seq_len,)
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             c = t // chunk_size
         else:
             with warnings.catch_warnings():
@@ -650,7 +650,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )

     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
-        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return None
         batch_size = x.shape[1]
         mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@@ -695,7 +695,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src_orig = src

         # dropout rate for non-feedforward submodules
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             attention_skip_rate = 0.0
         else:
             attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@@ -713,7 +713,7 @@ class Zipformer2EncoderLayer(nn.Module):
         self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

         selected_attn_weights = attn_weights[0:1]
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif not self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
@@ -732,7 +732,7 @@ class Zipformer2EncoderLayer(nn.Module):

         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@@ -740,7 +740,7 @@ class Zipformer2EncoderLayer(nn.Module):
                               src_key_padding_mask=src_key_padding_mask),
             conv_skip_rate)

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff2_skip_rate = 0.0
         else:
             ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@@ -754,7 +754,7 @@ class Zipformer2EncoderLayer(nn.Module):

         src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@@ -762,7 +762,7 @@ class Zipformer2EncoderLayer(nn.Module):
                               src_key_padding_mask=src_key_padding_mask),
             conv_skip_rate)

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff3_skip_rate = 0.0
         else:
             ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@@ -968,7 +968,7 @@ class Zipformer2Encoder(nn.Module):
         pos_emb = self.encoder_pos(src)
         output = src

-        if not torch.jit.is_scripting():
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             output = output * feature_mask

         for i, mod in enumerate(self.layers):
@@ -980,7 +980,7 @@ class Zipformer2Encoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
             )

-            if not torch.jit.is_scripting():
+            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                 output = output * feature_mask

         return output
@@ -1073,7 +1073,7 @@ class BypassModule(nn.Module):
         # or (batch_size, num_channels,). This is actually the
         # scale on the non-residual term, so 0 correponds to bypassing
         # this module.
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.bypass_scale
         else:
             ans = limit_param_value(self.bypass_scale,
@@ -1524,7 +1524,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)

         use_pos_scores = False
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # We can't put random.random() in the same line
             use_pos_scores = True
         elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         attn_scores = attn_scores + pos_scores

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
@@ -1604,7 +1604,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)

-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif random.random() < 0.001 and not self.training:
             self._print_attn_entropy(attn_weights)
@@ -2146,7 +2146,7 @@ class ConvolutionModule(nn.Module):
         if src_key_padding_mask is not None:
             x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

-        if not torch.jit.is_scripting() and chunk_size >= 0:
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
             # Not support exporting a model for simulated streaming decoding
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
             x = self.depthwise_conv(x, chunk_size=chunk_size)