zipformer2 logaddexp onnx safe

This commit is contained in:
MicKot 2023-06-30 12:26:31 +00:00 committed by GitHub
parent ccd8c624dd
commit a5aa5169a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,7 +36,9 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing(): if not torch.jit.is_tracing():
return torch.logaddexp(x, y) return torch.logaddexp(x, y)
else: else:
return (x.exp() + y.exp()).log() max_value = torch.max(x, y)
diff = torch.abs(x - y)
return max_value + torch.log1p(torch.exp(-diff))
class PiecewiseLinear(object): class PiecewiseLinear(object):
""" """