Merge pull request #2 from csukuangfj/export_zipformer2_onnx

Use torch.jit.script() for position encoding
This commit is contained in:
danfu 2023-06-06 15:46:24 +08:00 committed by GitHub
commit a2b8e3545b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 12 deletions

View File

@ -296,6 +296,8 @@ def export_encoder_model_onnx(
x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64)
encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
torch.onnx.export(
encoder_model,
(x, x_lens),
@ -523,7 +525,7 @@ def main():
model.to("cpu")
model.eval()
convert_scaled_to_non_scaled(model, inplace=True)
convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder(
encoder=model.encoder,

View File

@ -27,6 +27,7 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:

View File

@ -1305,12 +1305,6 @@ class CompactRelPositionalEncoding(torch.nn.Module):
) -> None:
"""Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__()
if torch.jit.is_tracing():
# It assumes that the maximum input, after downsampling, won't have more than
# 10k frames.
# The first downsampling factor is 2, so the maximum input
# should contain less than 20k frames, e.g., less than 200 seconds, i.e., 3.33 minutes
max_len = 10000
self.embed_dim = embed_dim
assert embed_dim % 2 == 0
self.dropout = Dropout2(dropout_rate)
@ -1327,11 +1321,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]