Fix ONNX export for the latest non-streaming zipformer. (#1160)

Fangjun Kuang 2023-07-03 23:56:51 +08:00 committed by GitHub
parent c3e23ec8d2
commit 9009d028a0
2 changed files with 34 additions and 4 deletions

egs/librispeech/ASR/zipformer/scaling.py

@@ -25,6 +25,11 @@ import math
 import torch.nn as nn
 from torch import Tensor
 
+
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
+    max_value = torch.max(x, y)
+    diff = torch.abs(x - y)
+    return max_value + torch.log1p(torch.exp(-diff))
 
 # RuntimeError: Exporting the operator logaddexp to ONNX opset version
 # 14 is not supported. Please feel free to request support or submit
@@ -45,9 +50,7 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
         # matches torch.onnx.export().
         return torch.logaddexp(x, y)
     elif torch.onnx.is_in_onnx_export():
-        max_value = torch.max(x, y)
-        diff = torch.abs(x - y)
-        return max_value + torch.log1p(torch.exp(-diff))
+        return logaddexp_onnx(x, y)
     else:
         # for torch.jit.trace()
        return torch.logaddexp(x, y)
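
ONNX opset 14 has no `logaddexp` operator, so the export path rewrites it with ops that do export (`max`, `abs`, `exp`, `log1p`), using the identity logaddexp(x, y) = max(x, y) + log1p(exp(-|x - y|)). Because the exponent is always <= 0, the rewrite cannot overflow even for large inputs. A minimal standalone check (not part of the commit) that the rewrite matches `torch.logaddexp`:

```python
import torch


def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # max(x, y) + log1p(exp(-|x - y|)); the exponent is <= 0, so no overflow.
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    return max_value + torch.log1p(torch.exp(-diff))


x = torch.tensor([0.0, 100.0, -100.0, 50.0])
y = torch.tensor([0.0, -100.0, 100.0, 50.0])
assert torch.allclose(logaddexp_onnx(x, y), torch.logaddexp(x, y))
```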
@@ -1348,6 +1351,13 @@ class SwooshL(torch.nn.Module):
         return k2.swoosh_l(x)
         # return SwooshLFunction.apply(x)
 
+
+class SwooshLOnnx(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swoosh-L activation.
+        """
+        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+        return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
 
 class SwooshRFunction(torch.autograd.Function):
     """
@@ -1414,6 +1424,13 @@ class SwooshR(torch.nn.Module):
         return k2.swoosh_r(x)
         # return SwooshRFunction.apply(x)
 
+
+class SwooshROnnx(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swoosh-R activation.
+        """
+        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+        return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687
 
 # simple version of SwooshL that does not redefine the backprop, used in
 # ActivationDropoutAndLinearFunction.
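
Swoosh-R is swoosh_r(x) = log(1 + exp(x - 1)) - 0.08x - 0.313261687, where the offset appears to be log(1 + exp(-1)), so the activation passes through the origin. A quick numeric check (not part of the commit, and the origin-crossing reading is an observation, not a claim from the diff):

```python
import math

import torch

# offset == log(1 + exp(-1)), which forces swoosh_r(0) == 0
assert math.isclose(0.313261687, math.log1p(math.exp(-1.0)), rel_tol=1e-8)

zero = torch.tensor(0.0)
swoosh_r_at_zero = torch.logaddexp(zero, zero - 1.0) - 0.08 * zero - 0.313261687
assert abs(swoosh_r_at_zero.item()) < 1e-6
```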

egs/librispeech/ASR/zipformer/scaling_converter.py

@@ -26,7 +26,16 @@ from typing import List, Tuple
 
 import torch
 import torch.nn as nn
-from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from scaling import (
+    Balancer,
+    Dropout3,
+    ScaleGrad,
+    SwooshL,
+    SwooshLOnnx,
+    SwooshR,
+    SwooshROnnx,
+    Whiten,
+)
 from zipformer import CompactRelPositionalEncoding
@@ -75,6 +84,10 @@ def convert_scaled_to_non_scaled(
     for name, m in model.named_modules():
         if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
             d[name] = nn.Identity()
+        elif is_onnx and isinstance(m, SwooshR):
+            d[name] = SwooshROnnx()
+        elif is_onnx and isinstance(m, SwooshL):
+            d[name] = SwooshLOnnx()
         elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
             # We want to recreate the positional encoding vector when
             # the input changes, so we have to use torch.jit.script()
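
`convert_scaled_to_non_scaled` walks `model.named_modules()`, records ONNX-friendly replacements in `d`, and then swaps each recorded module onto its parent, so with `is_onnx=True` every `SwooshL`/`SwooshR` becomes its `*Onnx` counterpart before `torch.onnx.export()` runs and never hits the unsupported `logaddexp` op or the k2 custom kernels. A self-contained sketch of that swap-by-name mechanism (simplified, not the exact icefall implementation; the two stand-in classes below are illustrations, not the icefall modules):

```python
import torch
import torch.nn as nn


class SwooshL(nn.Module):  # stand-in for the training-time module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log1p(torch.exp(x - 4.0)) - 0.08 * x - 0.035


class SwooshLOnnx(nn.Module):  # stand-in for the export-friendly module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
        return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035


model = nn.Sequential(nn.Linear(4, 4), SwooshL())

# Pass 1: collect replacements keyed by the module's qualified name.
d = {}
for name, m in model.named_modules():
    if isinstance(m, SwooshL):
        d[name] = SwooshLOnnx()

# Pass 2: assign each replacement onto its parent module.
for name, module in d.items():
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, module)

assert isinstance(model[1], SwooshLOnnx)
```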