Fix ONNX export for the latest non-streaming zipformer. (#1160)
parent c3e23ec8d2
commit 9009d028a0
@@ -25,6 +25,11 @@ import math
 import torch.nn as nn
 from torch import Tensor
 
+def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
+    max_value = torch.max(x, y)
+    diff = torch.abs(x - y)
+    return max_value + torch.log1p(torch.exp(-diff))
+
 
 # RuntimeError: Exporting the operator logaddexp to ONNX opset version
 # 14 is not supported. Please feel free to request support or submit
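The helper relies on the identity log(e^x + e^y) = max(x, y) + log1p(exp(-|x - y|)), which uses only operators the ONNX exporter can handle. A quick standalone check, not part of the patch, that the decomposition matches torch.logaddexp:

import torch

def logaddexp_onnx(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # same decomposition as the helper added above
    max_value = torch.max(x, y)
    diff = torch.abs(x - y)
    return max_value + torch.log1p(torch.exp(-diff))

x = torch.randn(1000)
y = torch.randn(1000)
assert torch.allclose(logaddexp_onnx(x, y), torch.logaddexp(x, y), atol=1e-6)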
@@ -45,9 +50,7 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
         # matches torch.onnx.export().
         return torch.logaddexp(x, y)
     elif torch.onnx.is_in_onnx_export():
-        max_value = torch.max(x, y)
-        diff = torch.abs(x - y)
-        return max_value + torch.log1p(torch.exp(-diff))
+        return logaddexp_onnx(x, y)
     else:
         # for torch.jit.trace()
         return torch.logaddexp(x, y)
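During torch.onnx.export() the module is traced with torch.onnx.is_in_onnx_export() returning True, so only the decomposed form ends up in the exported graph; eager mode, torch.jit.script() and torch.jit.trace() keep the native operator. A minimal, self-contained export sketch (TinyLogAddExp is a hypothetical test module, not part of the patch):

import io
import torch

class TinyLogAddExp(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            # ONNX-friendly decomposition; this branch is baked into the exported graph
            max_value = torch.max(x, y)
            return max_value + torch.log1p(torch.exp(-torch.abs(x - y)))
        # eager / tracing path keeps the native operator
        return torch.logaddexp(x, y)

x, y = torch.randn(2, 3), torch.randn(2, 3)
torch.onnx.export(TinyLogAddExp(), (x, y), io.BytesIO(), opset_version=14)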
@@ -1348,6 +1351,13 @@ class SwooshL(torch.nn.Module):
         return k2.swoosh_l(x)
         # return SwooshLFunction.apply(x)
 
+class SwooshLOnnx(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swoosh-L activation.
+        """
+        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+        return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035
+
 
 class SwooshRFunction(torch.autograd.Function):
     """
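SwooshLOnnx spells out the closed form of Swoosh-L, log(1 + exp(x - 4)) - 0.08 x - 0.035, with the ONNX-exportable helper; logaddexp_onnx(0, x - 4.0) computes the same quantity as softplus(x - 4.0). A standalone sanity check against that closed form (this mirrors the patch, it does not import it):

import torch

def swoosh_l_onnx(x: torch.Tensor) -> torch.Tensor:
    # mirrors SwooshLOnnx.forward with logaddexp_onnx inlined
    max_value = torch.max(torch.zeros_like(x), x - 4.0)
    return max_value + torch.log1p(torch.exp(-torch.abs(x - 4.0))) - 0.08 * x - 0.035

x = torch.linspace(-10.0, 10.0, steps=201)
ref = torch.nn.functional.softplus(x - 4.0) - 0.08 * x - 0.035
assert torch.allclose(swoosh_l_onnx(x), ref, atol=1e-6)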
@@ -1414,6 +1424,13 @@ class SwooshR(torch.nn.Module):
         return k2.swoosh_r(x)
         # return SwooshRFunction.apply(x)
 
+class SwooshROnnx(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swoosh-R activation.
+        """
+        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+        return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687
+
 
 # simple version of SwooshL that does not redefine the backprop, used in
 # ActivationDropoutAndLinearFunction.
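SwooshROnnx follows the same pattern for Swoosh-R, log(1 + exp(x - 1)) - 0.08 x - 0.313261687; the offset equals log(1 + exp(-1)) to the printed precision, so the activation is zero at the origin. A quick standalone check of both facts:

import math
import torch

# the Swoosh-R offset is log(1 + exp(-1)) ...
assert abs(math.log(1.0 + math.exp(-1.0)) - 0.313261687) < 1e-9

# ... so log(1 + exp(0 - 1)) - 0.08 * 0 - 0.313261687 evaluates to (numerically) zero
zero = torch.zeros(1)
out = torch.logaddexp(zero, zero - 1.0) - 0.08 * zero - 0.313261687
assert torch.allclose(out, torch.zeros(1), atol=1e-6)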
@@ -26,7 +26,16 @@ from typing import List, Tuple
 
 import torch
 import torch.nn as nn
-from scaling import Balancer, Dropout3, ScaleGrad, Whiten
+from scaling import (
+    Balancer,
+    Dropout3,
+    ScaleGrad,
+    SwooshL,
+    SwooshLOnnx,
+    SwooshR,
+    SwooshROnnx,
+    Whiten,
+)
 from zipformer import CompactRelPositionalEncoding
 
 
@@ -75,6 +84,10 @@ def convert_scaled_to_non_scaled(
     for name, m in model.named_modules():
         if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
             d[name] = nn.Identity()
+        elif is_onnx and isinstance(m, SwooshR):
+            d[name] = SwooshROnnx()
+        elif is_onnx and isinstance(m, SwooshL):
+            d[name] = SwooshLOnnx()
         elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
             # We want to recreate the positional encoding vector when
             # the input changes, so we have to use torch.jit.script()
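With these replacements in place, calling the converter with is_onnx=True before export swaps every SwooshL/SwooshR module for its ONNX-friendly counterpart, so the exporter never sees the unsupported logaddexp operator. A minimal sketch, assuming it runs from the recipe directory (so scaling.py and scaling_converter.py are importable) and that the converter replaces child modules by name as the hunk above suggests:

import torch.nn as nn
from scaling import SwooshL, SwooshLOnnx
from scaling_converter import convert_scaled_to_non_scaled

# A toy module standing in for a real zipformer encoder.
toy = nn.Sequential(nn.Linear(4, 4), SwooshL())
toy = convert_scaled_to_non_scaled(toy, is_onnx=True)
assert isinstance(toy[1], SwooshLOnnx)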