diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index fea669fd0..42c9187d9 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -296,6 +296,8 @@ def export_encoder_model_onnx(
     x = torch.zeros(1, 100, 80, dtype=torch.float32)
     x_lens = torch.tensor([100], dtype=torch.int64)
 
+    encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
+
     torch.onnx.export(
         encoder_model,
         (x, x_lens),
@@ -523,7 +525,7 @@ def main():
     model.to("cpu")
     model.eval()
 
-    convert_scaled_to_non_scaled(model, inplace=True)
+    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
 
     encoder = OnnxEncoder(
         encoder=model.encoder,
diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py
index 683a03461..54a5c2a6a 100644
--- a/egs/librispeech/ASR/zipformer/scaling_converter.py
+++ b/egs/librispeech/ASR/zipformer/scaling_converter.py
@@ -27,6 +27,7 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn
 from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from zipformer import CompactRelPositionalEncoding
 
 
 # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule  # noqa
@@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
     model: nn.Module,
     inplace: bool = False,
     is_pnnx: bool = False,
+    is_onnx: bool = False,
 ):
     """
     Args:
@@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
         If False, the input model is copied and we modify the copied version.
       is_pnnx:
         True if we are going to export the model for PNNX.
+      is_onnx:
+        True if we are going to export the model for ONNX.
     Return:
       Return a model without scaled layers.
     """
@@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
     for name, m in model.named_modules():
         if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
             d[name] = nn.Identity()
+        elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
+            # We want to recreate the positional encoding vector when
+            # the input length changes, so we use torch.jit.script()
+            # instead of torch.jit.trace() for this module.
+            d[name] = torch.jit.script(m)
 
     for k, v in d.items():
         if "." in k:
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 660bdeb1d..2cfc29e49 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -1305,12 +1305,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
-        if torch.jit.is_tracing():
-            # It assumes that the maximum input, after downsampling, won't have more than
-            # 10k frames.
-            # The first downsampling factor is 2, so the maximum input
-            # should contain less than 20k frames, e.g., less than 200 seconds, i.e., 3.33 minutes
-            max_len = 10000
         self.embed_dim = embed_dim
         assert embed_dim % 2 == 0
         self.dropout = Dropout2(dropout_rate)
@@ -1327,11 +1321,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # self.pe contains both positive and negative parts
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(0) >= T * 2 - 1:
-            # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                x.device
-            ):
-                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
 
         # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
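
For reference, below is a minimal sketch (not part of the patch) of how the pieces above are intended to fit together when exporting the Zipformer encoder to ONNX: convert_scaled_to_non_scaled(..., is_onnx=True) scripts CompactRelPositionalEncoding so its positional-encoding table is rebuilt whenever the input length changes, the wrapped encoder is then traced, and the traced module is handed to torch.onnx.export. The helper name, output path, opset version, and tensor names are assumptions for illustration; only the dummy input shapes, the trace-before-export step, and the is_onnx=True flag come from the diff itself.

import torch
import torch.nn as nn

from scaling_converter import convert_scaled_to_non_scaled


def export_encoder_to_onnx(
    model: nn.Module, encoder_model: nn.Module, filename: str
) -> None:
    """Hypothetical helper: `model` is the full Zipformer transducer (on CPU,
    in eval mode) and `encoder_model` is the OnnxEncoder wrapper built in
    export-onnx.py."""
    # Strip scaled layers and script CompactRelPositionalEncoding so the
    # positional-encoding table is recreated when the input length changes.
    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)

    # Dummy input: 1 utterance, 100 frames of 80-dim fbank features.
    x = torch.zeros(1, 100, 80, dtype=torch.float32)
    x_lens = torch.tensor([100], dtype=torch.int64)

    # Trace first; the scripted positional-encoding submodule keeps its
    # length-dependent control flow inside the exported graph.
    encoder_model = torch.jit.trace(encoder_model, (x, x_lens))

    torch.onnx.export(
        encoder_model,
        (x, x_lens),
        filename,
        opset_version=13,  # opset version is an assumption
        input_names=["x", "x_lens"],  # tensor names are assumptions
        output_names=["encoder_out", "encoder_out_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "encoder_out": {0: "N", 1: "T"},
            "encoder_out_lens": {0: "N"},
        },
    )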