Fix logaddexp for ONNX export (#1158)

This commit is contained in:
Fangjun Kuang 2023-07-02 10:30:09 +08:00 committed by GitHub
parent 98d89463f6
commit c3e23ec8d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -33,12 +33,24 @@ from torch import Tensor
# The following function is to solve the above error when exporting # The following function is to solve the above error when exporting
# models to ONNX via torch.jit.trace() # models to ONNX via torch.jit.trace()
def logaddexp(x: Tensor, y: Tensor) -> Tensor: def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing(): # Caution(fangjun): Put torch.jit.is_scripting() before
# torch.onnx.is_in_onnx_export();
# otherwise, it will cause errors for torch.jit.script().
#
# torch.logaddexp() works for both torch.jit.script() and
# torch.jit.trace() but it causes errors for ONNX export.
#
if torch.jit.is_scripting():
# Note: We cannot use torch.jit.is_tracing() here as it also
# matches torch.onnx.export().
return torch.logaddexp(x, y) return torch.logaddexp(x, y)
else: elif torch.onnx.is_in_onnx_export():
max_value = torch.max(x, y) max_value = torch.max(x, y)
diff = torch.abs(x - y) diff = torch.abs(x - y)
return max_value + torch.log1p(torch.exp(-diff)) return max_value + torch.log1p(torch.exp(-diff))
else:
# for torch.jit.trace()
return torch.logaddexp(x, y)
class PiecewiseLinear(object): class PiecewiseLinear(object):
""" """