mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Fix logaddexp for ONNX export (#1158)
This commit is contained in:
parent
98d89463f6
commit
c3e23ec8d2
@ -33,12 +33,24 @@ from torch import Tensor
|
||||
# The following function is to solve the above error when exporting
|
||||
# models to ONNX via torch.jit.trace()
|
||||
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
# Caution(fangjun): Put torch.jit.is_scripting() before
|
||||
# torch.onnx.is_in_onnx_export();
|
||||
# otherwise, it will cause errors for torch.jit.script().
|
||||
#
|
||||
# torch.logaddexp() works for both torch.jit.script() and
|
||||
# torch.jit.trace() but it causes errors for ONNX export.
|
||||
#
|
||||
if torch.jit.is_scripting():
|
||||
# Note: We cannot use torch.jit.is_tracing() here as it also
|
||||
# matches torch.onnx.export().
|
||||
return torch.logaddexp(x, y)
|
||||
else:
|
||||
elif torch.onnx.is_in_onnx_export():
|
||||
max_value = torch.max(x, y)
|
||||
diff = torch.abs(x - y)
|
||||
return max_value + torch.log1p(torch.exp(-diff))
|
||||
else:
|
||||
# for torch.jit.trace()
|
||||
return torch.logaddexp(x, y)
|
||||
|
||||
class PiecewiseLinear(object):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user