mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
zipformer2 logaddexp onnx safe (#1157)
This commit is contained in:
parent
ccd8c624dd
commit
98d89463f6
@ -36,7 +36,9 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor:
|
|||||||
if not torch.jit.is_tracing():
|
if not torch.jit.is_tracing():
|
||||||
return torch.logaddexp(x, y)
|
return torch.logaddexp(x, y)
|
||||||
else:
|
else:
|
||||||
return (x.exp() + y.exp()).log()
|
max_value = torch.max(x, y)
|
||||||
|
diff = torch.abs(x - y)
|
||||||
|
return max_value + torch.log1p(torch.exp(-diff))
|
||||||
|
|
||||||
class PiecewiseLinear(object):
|
class PiecewiseLinear(object):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user