Use torch.jit.script() for positional encoding

This commit is contained in:
Fangjun Kuang 2023-06-06 15:25:11 +08:00
parent 4cb2395186
commit d9e7f02225
3 changed files with 13 additions and 17 deletions

View File

@ -296,18 +296,7 @@ def export_encoder_model_onnx(
x = torch.zeros(1, 100, 80, dtype=torch.float32) x = torch.zeros(1, 100, 80, dtype=torch.float32)
x_lens = torch.tensor([100], dtype=torch.int64) x_lens = torch.tensor([100], dtype=torch.int64)
# It assumes that the maximum input, after downsampling, won't have more encoder_model = torch.jit.trace(encoder_model, (x, x_lens))
# than 10k frames.
# The first downsampling factor is 2, so the maximum input
# should contain less than 20k frames, e.g., less than 400 seconds,
# i.e., 3.3 minutes
#
# Note: If you want to handle a longer input audio, please increase this
# value. The downside is that it will increase the size of the model.
max_len = 10000
for name, m in encoder_model.named_modules():
if isinstance(m, CompactRelPositionalEncoding):
m.extend_pe(torch.tensor(0.0).expand(max_len))
torch.onnx.export( torch.onnx.export(
encoder_model, encoder_model,
@ -536,7 +525,7 @@ def main():
model.to("cpu") model.to("cpu")
model.eval() model.eval()
convert_scaled_to_non_scaled(model, inplace=True) convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
encoder = OnnxEncoder( encoder = OnnxEncoder(
encoder=model.encoder, encoder=model.encoder,

View File

@ -27,6 +27,7 @@ from typing import List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling import Balancer, Dropout3, ScaleGrad, Whiten from scaling import Balancer, Dropout3, ScaleGrad, Whiten
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa # Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
@ -51,6 +52,7 @@ def convert_scaled_to_non_scaled(
model: nn.Module, model: nn.Module,
inplace: bool = False, inplace: bool = False,
is_pnnx: bool = False, is_pnnx: bool = False,
is_onnx: bool = False,
): ):
""" """
Args: Args:
@ -61,6 +63,8 @@ def convert_scaled_to_non_scaled(
If False, the input model is copied and we modify the copied version. If False, the input model is copied and we modify the copied version.
is_pnnx: is_pnnx:
True if we are going to export the model for PNNX. True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return: Return:
Return a model without scaled layers. Return a model without scaled layers.
""" """
@ -71,6 +75,11 @@ def convert_scaled_to_non_scaled(
for name, m in model.named_modules(): for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity() d[name] = nn.Identity()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# to replace torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items(): for k, v in d.items():
if "." in k: if "." in k:

View File

@ -1322,10 +1322,8 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1 # the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1: if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device # Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str( # if self.pe.dtype != x.dtype or self.pe.device != x.device:
x.device self.pe = self.pe.to(dtype=x.dtype, device=x.device)
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return return
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]