From c3e23ec8d2a3ed2547bd94dee7280bd3f193a47e Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 2 Jul 2023 10:30:09 +0800 Subject: [PATCH] Fix logaddexp for ONNX export (#1158) --- egs/librispeech/ASR/zipformer/scaling.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 78c4efdc1..885f8f143 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -33,12 +33,24 @@ from torch import Tensor # The following function is to solve the above error when exporting # models to ONNX via torch.jit.trace() def logaddexp(x: Tensor, y: Tensor) -> Tensor: - if not torch.jit.is_tracing(): + # Caution(fangjun): Put torch.jit.is_scripting() before + # torch.onnx.is_in_onnx_export(); + # otherwise, it will cause errors for torch.jit.script(). + # + # torch.logaddexp() works for both torch.jit.script() and + # torch.jit.trace() but it causes errors for ONNX export. + # + if torch.jit.is_scripting(): + # Note: We cannot use torch.jit.is_tracing() here as it also + # matches torch.onnx.export(). return torch.logaddexp(x, y) - else: + elif torch.onnx.is_in_onnx_export(): max_value = torch.max(x, y) diff = torch.abs(x - y) return max_value + torch.log1p(torch.exp(-diff)) + else: + # for torch.jit.trace() + return torch.logaddexp(x, y) class PiecewiseLinear(object): """