From d9e7f022251386022a87964a79af573e6856e78f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Tue, 6 Jun 2023 15:25:11 +0800
Subject: [PATCH] Use torch.jit.script() for positional encoding

---
 egs/librispeech/ASR/zipformer/export-onnx.py | 15 ++-------------
 .../ASR/zipformer/scaling_converter.py       |  9 +++++++++
 egs/librispeech/ASR/zipformer/zipformer.py   |  6 ++----
 3 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/export-onnx.py b/egs/librispeech/ASR/zipformer/export-onnx.py
index e8777b13c..6572f79d5 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx.py
@@ -296,18 +296,7 @@ def export_encoder_model_onnx(
     x = torch.zeros(1, 100, 80, dtype=torch.float32)
     x_lens = torch.tensor([100], dtype=torch.int64)
 
-    # It assumes that the maximum input, after downsampling, won't have more
-    # than 10k frames.
-    # The first downsampling factor is 2, so the maximum input
-    # should contain less than 20k frames, e.g., less than 400 seconds,
-    # i.e., 3.3 minutes
-    #
-    # Note: If you want to handle a longer input audio, please increase this
-    # value. The downside is that it will increase the size of the model.
-    max_len = 10000
-    for name, m in encoder_model.named_modules():
-        if isinstance(m, CompactRelPositionalEncoding):
-            m.extend_pe(torch.tensor(0.0).expand(max_len))
+    encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
 
     torch.onnx.export(
         encoder_model,
@@ -536,7 +525,7 @@ def main():
     model.to("cpu")
     model.eval()
 
-    convert_scaled_to_non_scaled(model, inplace=True)
+    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
 
     encoder = OnnxEncoder(
         encoder=model.encoder,
diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py
index 683a03461..54a5c2a6a 100644
--- a/egs/librispeech/ASR/zipformer/scaling_converter.py
+++ b/egs/librispeech/ASR/zipformer/scaling_converter.py
@@ -27,6 +27,7 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn
 from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from zipformer import CompactRelPositionalEncoding
 
 
 # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule  # noqa
@@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
     model: nn.Module,
     inplace: bool = False,
     is_pnnx: bool = False,
+    is_onnx: bool = False,
 ):
     """
     Args:
@@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
         If False, the input model is copied and we modify the copied version.
       is_pnnx:
         True if we are going to export the model for PNNX.
+      is_onnx:
+        True if we are going to export the model for ONNX.
     Return:
       Return a model without scaled layers.
     """
@@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
     for name, m in model.named_modules():
         if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
             d[name] = nn.Identity()
+        elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
+            # We want to recreate the positional encoding vector when
+            # the input changes, so we have to use torch.jit.script()
+            # to replace torch.jit.trace()
+            d[name] = torch.jit.script(m)
 
     for k, v in d.items():
         if "." in k:
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 15022947f..7dd794e0e 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -1322,10 +1322,8 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # the length of self.pe is 2 * input_len - 1
         if self.pe.size(0) >= T * 2 - 1:
             # Note: TorchScript doesn't implement operator== for torch.Device
-            if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                x.device
-            ):
-                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+            # if self.pe.dtype != x.dtype or self.pe.device != x.device:
+            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
             return
 
         # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
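
Note (not part of the patch): the sketch below is a hypothetical, self-contained illustration of why the positional-encoding module is scripted rather than traced. ToyPosEnc is a made-up stand-in for CompactRelPositionalEncoding: it keeps a cached table and rebuilds it whenever a longer input arrives. torch.jit.trace() records only the branch taken on the example input, so the cache can never grow in the traced module, whereas torch.jit.script() keeps the data-dependent `if` and re-extends the cache as needed.

# Hypothetical sketch, not code from the icefall recipe: ToyPosEnc only
# mimics the caching behaviour of CompactRelPositionalEncoding.extend_pe().
import torch
import torch.nn as nn


class ToyPosEnc(nn.Module):
    def __init__(self):
        super().__init__()
        # Cache built for T = 10, i.e. 2 * 10 - 1 = 19 positions.
        self.pe = torch.zeros(19)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T = x.size(1)
        if self.pe.size(0) < 2 * T - 1:
            # Data-dependent branch: rebuild the cache for a longer input.
            self.pe = torch.arange(2 * T - 1, dtype=x.dtype)
        return self.pe[: 2 * T - 1]


x_short = torch.zeros(1, 10, 8)  # fits the prebuilt cache
x_long = torch.zeros(1, 50, 8)   # needs a bigger cache

traced = torch.jit.trace(ToyPosEnc(), x_short)
scripted = torch.jit.script(ToyPosEnc())

print(traced(x_long).shape)    # torch.Size([19]): the rebuild branch was never recorded
print(scripted(x_long).shape)  # torch.Size([99]): the branch survives scripting

Because only this submodule needs the data-dependent branch, the patch scripts it inside convert_scaled_to_non_scaled() and still lets export-onnx.py trace the rest of the encoder with torch.jit.trace().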