mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix logaddexp for ONNX export (#1158)
This commit is contained in:
parent
98d89463f6
commit
c3e23ec8d2
@ -33,12 +33,24 @@ from torch import Tensor
|
|||||||
# The following function is to solve the above error when exporting
|
# The following function is to solve the above error when exporting
|
||||||
# models to ONNX via torch.jit.trace()
|
# models to ONNX via torch.jit.trace()
|
||||||
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
||||||
if not torch.jit.is_tracing():
|
# Caution(fangjun): Put torch.jit.is_scripting() before
|
||||||
|
# torch.onnx.is_in_onnx_export();
|
||||||
|
# otherwise, it will cause errors for torch.jit.script().
|
||||||
|
#
|
||||||
|
# torch.logaddexp() works for both torch.jit.script() and
|
||||||
|
# torch.jit.trace() but it causes errors for ONNX export.
|
||||||
|
#
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
# Note: We cannot use torch.jit.is_tracing() here as it also
|
||||||
|
# matches torch.onnx.export().
|
||||||
return torch.logaddexp(x, y)
|
return torch.logaddexp(x, y)
|
||||||
else:
|
elif torch.onnx.is_in_onnx_export():
|
||||||
max_value = torch.max(x, y)
|
max_value = torch.max(x, y)
|
||||||
diff = torch.abs(x - y)
|
diff = torch.abs(x - y)
|
||||||
return max_value + torch.log1p(torch.exp(-diff))
|
return max_value + torch.log1p(torch.exp(-diff))
|
||||||
|
else:
|
||||||
|
# for torch.jit.trace()
|
||||||
|
return torch.logaddexp(x, y)
|
||||||
|
|
||||||
class PiecewiseLinear(object):
|
class PiecewiseLinear(object):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user