Fix constants in SwooshFunction.

This commit is contained in:
Daniel Povey 2022-12-02 18:37:23 +08:00
parent 14267a5194
commit 2bfc38207c

View File

@ -1235,7 +1235,7 @@ class SwooshFunction(torch.autograd.Function):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
y = torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
y = torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3
if not requires_grad:
return y
@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module):
"""
if torch.jit.is_scripting():
one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
return torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3
return SwooshFunction.apply(x)