zipformer2 logaddexp onnx safe (#1157)

This commit is contained in:
MicKot 2023-06-30 15:16:40 +02:00 committed by GitHub
parent ccd8c624dd
commit 98d89463f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,7 +36,9 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
if not torch.jit.is_tracing():
return torch.logaddexp(x, y)
else:
return (x.exp() + y.exp()).log()
max_value = torch.max(x, y)
diff = torch.abs(x - y)
return max_value + torch.log1p(torch.exp(-diff))
class PiecewiseLinear(object):
"""