Merge pull request #2 from csukuangfj/export_zipformer2_onnx

Use torch.jit.script() for position encoding
This commit is contained in:
danfu 2023-06-06 15:46:24 +08:00 committed by GitHub
commit a2b8e3545b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 12 deletions

View File

@ -296,6 +296,8 @@ def export_encoder_model_onnx(
x = torch.zeros(1, 100, 80, dtype=torch.float32) x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64) x_lens = torch.tensor([100], dtype=torch.int64)
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export( torch.onnx.export(
encoder_model, encoder_model,
(x, x_lens), (x, x_lens),
@ -523,7 +525,7 @@ def main():
model.to("cpu") model.to("cpu")
model.eval() model.eval()
convert_scaled_to_non_scaled(model, inplace=True) convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder( encoder = OnnxEncoder(
encoder=model.encoder, encoder=model.encoder,

View File

@ -27,6 +27,7 @@ from typing import List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten from scaling import Balancer, Dropout3, ScaleGrad, Whiten
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
model: nn.Module, model: nn.Module,
inplace: bool = False, inplace: bool = False,
is_pnnx: bool = False, is_pnnx: bool = False,
is_onnx: bool = False,
): ):
""" """
Args: Args:
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
If False, the input model is copied and we modify the copied version. If False, the input model is copied and we modify the copied version.
is_pnnx: is_pnnx:
True if we are going to export the model for PNNX. True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return: Return:
Return a model without scaled layers. Return a model without scaled layers.
""" """
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules(): for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity() d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# The positional encoding vector must be recomputed whenever the
# input length changes, so this module is compiled with
# torch.jit.script() instead of being traced by torch.jit.trace().
d[name] = torch.jit.script(m)
for k, v in d.items(): for k, v in d.items():
if "." in k: if "." in k:

View File

@ -1305,12 +1305,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
) -> None: ) -> None:
"""Construct a CompactRelPositionalEncoding object.""" """Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__() super(CompactRelPositionalEncoding, self).__init__()
if torch.jit.is_tracing():
# This assumes that the input, after downsampling, has no more than
# 10k frames.
# The first downsampling factor is 2, so the input itself
# should contain fewer than 20k frames, i.e., less than 200 seconds (3.33 minutes).
max_len = 10000
self.embed_dim = embed_dim self.embed_dim = embed_dim
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
self.dropout = Dropout2(dropout_rate) self.dropout = Dropout2(dropout_rate)
@ -1327,11 +1321,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# self.pe contains both positive and negative parts # self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1: if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device self.pe = self.pe.to(dtype=x.dtype, device=x.device)
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# if T == 4, x would contain [ -3, -2, -1, 0, 1, 2, 3 ]