Use torch.jit.script() for positional encoding

Fangjun Kuang 2023-06-06 15:25:11 +08:00
parent 4cb2395186
commit d9e7f02225
3 changed files with 13 additions and 17 deletions

@@ -296,18 +296,7 @@ def export_encoder_model_onnx(
     x = torch.zeros(1, 100, 80, dtype=torch.float32)
     x_lens = torch.tensor([100], dtype=torch.int64)
-    # It assumes that the maximum input, after downsampling, won't have more
-    # than 10k frames.
-    # The first downsampling factor is 2, so the maximum input
-    # should contain less than 20k frames, e.g., less than 400 seconds,
-    # i.e., 3.3 minutes
-    #
-    # Note: If you want to handle a longer input audio, please increase this
-    # value. The downside is that it will increase the size of the model.
-    max_len = 10000
-    for name, m in encoder_model.named_modules():
-        if isinstance(m, CompactRelPositionalEncoding):
-            m.extend_pe(torch.tensor(0.0).expand(max_len))
     encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
     torch.onnx.export(
         encoder_model,
@@ -536,7 +525,7 @@ def main():
     model.to("cpu")
     model.eval()
-    convert_scaled_to_non_scaled(model, inplace=True)
+    convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
     encoder = OnnxEncoder(
         encoder=model.encoder,
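
Taken together, the two hunks above swap the old workaround (pre-extending every CompactRelPositionalEncoding to max_len = 10000 before tracing) for scripting those modules via convert_scaled_to_non_scaled(..., is_onnx=True). The reason: torch.jit.trace() only records the operations run for the example input, so the lazily grown positional-encoding cache would otherwise stay at whatever length the trace saw, while torch.jit.script() keeps the data-dependent logic in extend_pe(). A minimal sketch of the difference, using a hypothetical ToyRelPosEncoding instead of the real CompactRelPositionalEncoding:

import torch
import torch.nn as nn


class ToyRelPosEncoding(nn.Module):
    """Caches a (2*T-1, d_model) table and regrows it when the input gets longer."""

    def __init__(self, d_model: int = 4):
        super().__init__()
        self.d_model = d_model
        self.pe = torch.zeros(0, d_model)

    def extend_pe(self, x: torch.Tensor) -> None:
        T = x.size(1)
        if self.pe.size(0) >= 2 * T - 1:
            # Cache is already long enough; just sync dtype/device.
            self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        pos = torch.arange(-(T - 1), T, dtype=x.dtype, device=x.device).unsqueeze(1)
        self.pe = pos.expand(2 * T - 1, self.d_model).contiguous()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        T = x.size(1)
        mid = self.pe.size(0) // 2
        return self.pe[mid - (T - 1) : mid + T]


short = torch.zeros(1, 10, 4)
long_x = torch.zeros(1, 50, 4)

# trace() freezes the branch taken and the sizes seen for the example input
# (it also emits TracerWarnings about this), so a longer input still gets
# the table that was built for T=10.
traced = torch.jit.trace(ToyRelPosEncoding(), short)
print(traced(long_x).shape)  # torch.Size([19, 4])

# script() keeps extend_pe()'s control flow, so the cache regrows at run time.
scripted = torch.jit.script(ToyRelPosEncoding())
print(scripted(long_x).shape)  # torch.Size([99, 4])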

@@ -27,6 +27,7 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn
 from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from zipformer import CompactRelPositionalEncoding
 # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
     model: nn.Module,
     inplace: bool = False,
     is_pnnx: bool = False,
+    is_onnx: bool = False,
 ):
     """
     Args:
@@ -61,6 +63,8 @@
         If False, the input model is copied and we modify the copied version.
       is_pnnx:
         True if we are going to export the model for PNNX.
+      is_onnx:
+        True if we are going to export the model for ONNX.
     Return:
       Return a model without scaled layers.
     """
@@ -71,6 +75,11 @@
     for name, m in model.named_modules():
         if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
             d[name] = nn.Identity()
+        elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
+            # We want to recreate the positional encoding vector when
+            # the input changes, so we have to use torch.jit.script()
+            # to replace torch.jit.trace()
+            d[name] = torch.jit.script(m)
     for k, v in d.items():
         if "." in k:

@@ -1322,10 +1322,8 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             # the length of self.pe is 2 * input_len - 1
             if self.pe.size(0) >= T * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
-                if self.pe.dtype != x.dtype or str(self.pe.device) != str(
-                    x.device
-                ):
-                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                # if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                 return
         # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
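
In zipformer.py itself, the guarded dtype/device check is simply replaced by an unconditional self.pe.to(...). A quick check (not part of the commit) of why dropping the guard is harmless: Tensor.to() returns the original tensor unchanged when the requested dtype and device already match, so the common case involves no copy:

import torch

pe = torch.zeros(7, 4, dtype=torch.float32)
x = torch.zeros(1, 4, 4, dtype=torch.float32)

same = pe.to(dtype=x.dtype, device=x.device)
print(same is pe)  # True: nothing to convert, the cached table is reused as-is

half = pe.to(dtype=torch.float16)
print(half is pe, half.dtype)  # False torch.float16: a converted copy otherwise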