Fix constants in SwooshFunction.
parent 14267a5194
commit 2bfc38207c
@@ -1235,7 +1235,7 @@ class SwooshFunction(torch.autograd.Function):
         with torch.enable_grad():
             x = x.detach()
             x.requires_grad = True
-            y = torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
+            y = torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3

         if not requires_grad:
             return y
@@ -1273,7 +1273,7 @@ class Swoosh(torch.nn.Module):
         """
         if torch.jit.is_scripting():
             one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
-            return torch.logaddexp(one, x - 4) - 0.08 * x - 0.15
+            return torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3
         return SwooshFunction.apply(x)


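For reference, a minimal standalone sketch of the expression the patch switches to. It is not the repository's SwooshFunction.apply path (that runs through a custom torch.autograd.Function, as the hunk header shows); it only evaluates the new formula with ordinary autograd so the updated constants can be checked in isolation:

import torch

def swoosh(x: torch.Tensor) -> torch.Tensor:
    # Same expression as the patched lines:
    # logaddexp(1, x - 1.125) - 0.08 * x - 0.3
    one = torch.tensor(1.0, dtype=x.dtype, device=x.device)
    return torch.logaddexp(one, x - 1.125) - 0.08 * x - 0.3

x = torch.linspace(-4.0, 4.0, steps=9, requires_grad=True)
y = swoosh(x)
y.sum().backward()
print(y)       # activation values under the new constants
print(x.grad)  # gradients via ordinary autograd (no custom backward here)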