From dd91c89f28dbd1af11cee49deb06a9ad00e9844c Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 2 Jun 2023 14:21:28 +0800
Subject: [PATCH 1/2] Fix exporting non-streaming zipformer to ONNX via
 torch.jit.trace()

---
 egs/librispeech/ASR/zipformer/export-onnx.py | 25 ++++++++++--
 egs/librispeech/ASR/zipformer/model.py       |  2 +-
 .../ASR/zipformer/onnx_pretrained.py         |  1 +
 egs/librispeech/ASR/zipformer/scaling.py     | 30 +++++++++-----
 egs/librispeech/ASR/zipformer/subsampling.py |  4 +-
 egs/librispeech/ASR/zipformer/zipformer.py   | 40 +++++++++----------
 6 files changed, 66 insertions(+), 36 deletions(-)
 create mode 120000 egs/librispeech/ASR/zipformer/onnx_pretrained.py

diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index 05d4a877c..57324b180 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -13,7 +13,7 @@ as an example to show how to use this file.
 
 cd egs/librispeech/ASR
 
-repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15
 GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
 repo=$(basename $repo_url)
 
@@ -27,12 +27,29 @@ popd
 
 2. Export the model to ONNX
 
-./zipformer/export-onnx.py \ 148 ↵
+./zipformer/export-onnx.py \
   --bpe-model $repo/data/lang_bpe_500/bpe.model \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \
-  --exp-dir $repo/exp/
+  --exp-dir $repo/exp \
+  \
+  --num-encoder-layers "2,2,3,4,3,2" \
+  --downsampling-factor "1,2,4,8,4,2" \
+  --feedforward-dim "512,768,1024,1536,1024,768" \
+  --num-heads "4,4,4,8,4,4" \
+  --encoder-dim "192,256,384,512,384,256" \
+  --query-head-dim 32 \
+  --value-head-dim 12 \
+  --pos-head-dim 4 \
+  --pos-dim 48 \
+  --encoder-unmasked-dim "192,192,256,256,256,192" \
+  --cnn-module-kernel "31,31,15,15,15,31" \
+  --decoder-dim 512 \
+  --joiner-dim 512 \
+  --causal False \
+  --chunk-size "16,32,64,-1" \
+  --left-context-frames "64,128,256,-1"
 
 It will generate the following 3 files inside $repo/exp:
 
@@ -299,7 +316,7 @@ def export_encoder_model_onnx(
         "model_type": "zipformer",
         "version": "1",
         "model_author": "k2-fsa",
-        "comment": "stateless7",
+        "comment": "zipformer",
     }
     logging.info(f"meta_data: {meta_data}")
 
diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index 7fcab04ae..ea2b4b721 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -49,7 +49,7 @@ class Transducer(nn.Module):
         encoder:
           It is the transcription network in the paper. Its accepts
           two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
-          It returns two tensors: `logits` of shape (N, T, encoder_dm) and
+          It returns two tensors: `logits` of shape (N, T, encoder_dim) and
           `logit_lens` of shape (N,).
         decoder:
           It is the prediction network in the paper. Its input shape
diff --git a/egs/librispeech/ASR/zipformer/onnx_pretrained.py b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
new file mode 120000
index 000000000..0069288fe
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/onnx_pretrained.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py
index 908b60938..9f23eeead 100644
--- a/egs/librispeech/ASR/zipformer/scaling.py
+++ b/egs/librispeech/ASR/zipformer/scaling.py
@@ -26,6 +26,18 @@ import torch.nn as nn
 from torch import Tensor
 
 
+# RuntimeError: Exporting the operator logaddexp to ONNX opset version
+# 14 is not supported. Please feel free to request support or submit
+# a pull request on PyTorch GitHub.
+#
+# The following function is to solve the above error when exporting
+# models to ONNX via torch.jit.trace()
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+    if not torch.jit.is_tracing():
+        return torch.logaddexp(x, y)
+    else:
+        return (x.exp() + y.exp()).log()
+
 class PiecewiseLinear(object):
     """
     Piecewise linear function, from float to float, specified as nonempty list of (x,y) pairs with
@@ -162,7 +174,7 @@ class ScheduledFloat(torch.nn.Module):
 
     def __float__(self):
         batch_count = self.batch_count
-        if batch_count is None or not self.training or torch.jit.is_scripting():
+        if batch_count is None or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return float(self.default)
         else:
             ans = self.schedule(self.batch_count)
@@ -268,7 +280,7 @@ class SoftmaxFunction(torch.autograd.Function):
 
 
 def softmax(x: Tensor, dim: int):
-    if not x.requires_grad or torch.jit.is_scripting():
+    if not x.requires_grad or torch.jit.is_scripting() or torch.jit.is_tracing():
         return x.softmax(dim=dim)
 
     return SoftmaxFunction.apply(x, dim)
@@ -1073,7 +1085,7 @@ class ScaleGrad(nn.Module):
         self.alpha = alpha
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return x
         return scale_grad(x, self.alpha)
 
@@ -1115,7 +1127,7 @@ def limit_param_value(x: Tensor,
 
 
 def _no_op(x: Tensor) -> Tensor:
-    if (torch.jit.is_scripting()):
+    if torch.jit.is_scripting() or torch.jit.is_tracing():
         return x
     else:
         # a no-op function that will have a node in the autograd graph,
@@ -1198,7 +1210,7 @@ class DoubleSwish(torch.nn.Module):
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x * torch.sigmoid(x - 1.0)
         return DoubleSwishFunction.apply(x)
 
@@ -1313,9 +1325,9 @@ class SwooshL(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-L activation.
        """
-       if torch.jit.is_scripting():
+       if torch.jit.is_scripting() or torch.jit.is_tracing():
            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-           return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
+           return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
        if not x.requires_grad:
            return k2.swoosh_l_forward(x)
        else:
@@ -1379,9 +1391,9 @@ class SwooshR(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh-R activation.
        """
-       if torch.jit.is_scripting():
+       if torch.jit.is_scripting() or torch.jit.is_tracing():
            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
-           return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+           return logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
        if not x.requires_grad:
            return k2.swoosh_r_forward(x)
        else:
diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py
index 47403f13c..d6bf57db4 100644
--- a/egs/librispeech/ASR/zipformer/subsampling.py
+++ b/egs/librispeech/ASR/zipformer/subsampling.py
@@ -100,7 +100,7 @@ class ConvNeXt(nn.Module):
         )
 
     def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.forward_internal(x)
         layerdrop_rate = float(self.layerdrop_rate)
 
@@ -322,7 +322,7 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_norm(x)
         x = self.dropout(x)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             x_lens = (x_lens - 7) // 2
         else:
             with warnings.catch_warnings():
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index ab3c2cfc6..ea4e6711f 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -258,7 +258,7 @@ class Zipformer2(EncoderInterface):
         if not self.causal:
             return -1, -1
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             assert len(self.chunk_size) == 1, self.chunk_size
             chunk_size = self.chunk_size[0]
         else:
@@ -267,7 +267,7 @@ class Zipformer2(EncoderInterface):
         if chunk_size == -1:
             left_context_chunks = -1
         else:
-            if torch.jit.is_scripting():
+            if torch.jit.is_scripting() or torch.jit.is_tracing():
                 assert len(self.left_context_frames) == 1, self.left_context_frames
                 left_context_frames = self.left_context_frames[0]
             else:
@@ -301,14 +301,14 @@ class Zipformer2(EncoderInterface):
           of frames in `embeddings` before padding.
         """
         outputs = []
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             feature_masks = [1.0] * len(self.encoder_dim)
         else:
             feature_masks = self.get_feature_masks(x)
 
         chunk_size, left_context_chunks = self.get_chunk_info()
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # Not support exporting a model for simulating streaming decoding
             attn_mask = None
         else:
@@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface):
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             lengths = (x_lens + 1) // 2
         else:
             with warnings.catch_warnings():
@@ -372,7 +372,7 @@ class Zipformer2(EncoderInterface):
         # t is frame index, shape (seq_len,)
         t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
         # c is chunk index for each frame, shape (seq_len,)
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             c = t // chunk_size
         else:
             with warnings.catch_warnings():
@@ -650,7 +650,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )
 
     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
-        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
+        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
             return None
         batch_size = x.shape[1]
         mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
@@ -695,7 +695,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src_orig = src
 
         # dropout rate for non-feedforward submodules
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             attention_skip_rate = 0.0
         else:
             attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0
@@ -713,7 +713,7 @@ class Zipformer2EncoderLayer(nn.Module):
         self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         selected_attn_weights = attn_weights[0:1]
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif not self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
@@ -732,7 +732,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src = src + (self_attn if self_attn_dropout_mask is None
                      else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@@ -740,7 +740,7 @@ class Zipformer2EncoderLayer(nn.Module):
                                                src_key_padding_mask=src_key_padding_mask),
                                conv_skip_rate)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff2_skip_rate = 0.0
         else:
             ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0
@@ -754,7 +754,7 @@ class Zipformer2EncoderLayer(nn.Module):
         src = src + (self_attn if self_attn_dropout_mask is None
                      else self_attn * self_attn_dropout_mask)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             conv_skip_rate = 0.0
         else:
             conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
@@ -762,7 +762,7 @@ class Zipformer2EncoderLayer(nn.Module):
                                                src_key_padding_mask=src_key_padding_mask),
                                conv_skip_rate)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             ff3_skip_rate = 0.0
         else:
             ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0
@@ -968,7 +968,7 @@ class Zipformer2Encoder(nn.Module):
         pos_emb = self.encoder_pos(src)
         output = src
 
-        if not torch.jit.is_scripting():
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
             output = output * feature_mask
 
         for i, mod in enumerate(self.layers):
@@ -980,7 +980,7 @@ class Zipformer2Encoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
             )
 
-            if not torch.jit.is_scripting():
+            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                 output = output * feature_mask
 
         return output
@@ -1073,7 +1073,7 @@ class BypassModule(nn.Module):
         # or (batch_size, num_channels,). This is actually the
         # scale on the non-residual term, so 0 correponds to bypassing
         # this module.
-        if torch.jit.is_scripting() or not self.training:
+        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
             return self.bypass_scale
         else:
             ans = limit_param_value(self.bypass_scale,
@@ -1524,7 +1524,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)
 
         use_pos_scores = False
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             # We can't put random.random() in the same line
             use_pos_scores = True
         elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
@@ -1561,7 +1561,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 
         attn_scores = attn_scores + pos_scores
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif self.training and random.random() < 0.1:
             # This is a harder way of limiting the attention scores to not be
@@ -1604,7 +1604,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         # half-precision output for backprop purposes.
         attn_weights = softmax(attn_scores, dim=-1)
 
-        if torch.jit.is_scripting():
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
         elif random.random() < 0.001 and not self.training:
             self._print_attn_entropy(attn_weights)
@@ -2146,7 +2146,7 @@ class ConvolutionModule(nn.Module):
         if src_key_padding_mask is not None:
             x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
 
-        if not torch.jit.is_scripting() and chunk_size >= 0:
+        if not torch.jit.is_scripting() and not torch.jit.is_tracing() and chunk_size >= 0:
             # Not support exporting a model for simulated streaming decoding
             assert self.causal, "Must initialize model with causal=True if you use chunk_size"
             x = self.depthwise_conv(x, chunk_size=chunk_size)

From 7b00b34617dcba49c3c22e5f0b4ab2e7bc249d8b Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Fri, 2 Jun 2023 14:23:47 +0800
Subject: [PATCH 2/2] minor fixes

---
 egs/librispeech/ASR/zipformer/export-onnx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index 57324b180..45586c048 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -132,7 +132,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless5/exp",
+        default="zipformer/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,