diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 885f8f143..4ee7b7826 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -25,6 +25,11 @@ import math import torch.nn as nn from torch import Tensor +def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor: + max_value = torch.max(x, y) + diff = torch.abs(x - y) + return max_value + torch.log1p(torch.exp(-diff)) + # RuntimeError: Exporting the operator logaddexp to ONNX opset version # 14 is not supported. Please feel free to request support or submit @@ -45,9 +50,7 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: # matches torch.onnx.export(). return torch.logaddexp(x, y) elif torch.onnx.is_in_onnx_export(): - max_value = torch.max(x, y) - diff = torch.abs(x - y) - return max_value + torch.log1p(torch.exp(-diff)) + return logaddexp_onnx(x, y) else: # for torch.jit.trace() return torch.logaddexp(x, y) @@ -1348,6 +1351,13 @@ class SwooshL(torch.nn.Module): return k2.swoosh_l(x) # return SwooshLFunction.apply(x) +class SwooshLOnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-L activation. + """ + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 4.0) - 0.08 * x - 0.035 + class SwooshRFunction(torch.autograd.Function): """ @@ -1414,6 +1424,13 @@ class SwooshR(torch.nn.Module): return k2.swoosh_r(x) # return SwooshRFunction.apply(x) +class SwooshROnnx(torch.nn.Module): + def forward(self, x: Tensor) -> Tensor: + """Return Swoosh-R activation. + """ + zero = torch.tensor(0.0, dtype=x.dtype, device=x.device) + return logaddexp_onnx(zero, x - 1.) - 0.08 * x - 0.313261687 + # simple version of SwooshL that does not redefine the backprop, used in # ActivationDropoutAndLinearFunction. diff --git a/egs/librispeech/ASR/zipformer/scaling_converter.py b/egs/librispeech/ASR/zipformer/scaling_converter.py index 54a5c2a6a..76622fa12 100644 --- a/egs/librispeech/ASR/zipformer/scaling_converter.py +++ b/egs/librispeech/ASR/zipformer/scaling_converter.py @@ -26,7 +26,16 @@ from typing import List, Tuple import torch import torch.nn as nn -from scaling import Balancer, Dropout3, ScaleGrad, Whiten +from scaling import ( + Balancer, + Dropout3, + ScaleGrad, + SwooshL, + SwooshLOnnx, + SwooshR, + SwooshROnnx, + Whiten, +) from zipformer import CompactRelPositionalEncoding @@ -75,6 +84,10 @@ def convert_scaled_to_non_scaled( for name, m in model.named_modules(): if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)): d[name] = nn.Identity() + elif is_onnx and isinstance(m, SwooshR): + d[name] = SwooshROnnx() + elif is_onnx and isinstance(m, SwooshL): + d[name] = SwooshLOnnx() elif is_onnx and isinstance(m, CompactRelPositionalEncoding): # We want to recreate the positional encoding vector when # the input changes, so we have to use torch.jit.script()