Change x coeff from -0.1 to -0.08, as in #608.

This commit is contained in:
Daniel Povey 2022-12-04 21:15:49 +08:00
parent 7b1f093077
commit c57eaf7979

View File

@ -1216,7 +1216,7 @@ class TanSwish(torch.nn.Module):
class SwooshLFunction(torch.autograd.Function):
"""
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
derivatives are between -0.08 and 0.92.
"""
@ -1235,7 +1235,7 @@ class SwooshLFunction(torch.autograd.Function):
with torch.enable_grad():
x = x.detach()
x.requires_grad = True
y = torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
y = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
if not requires_grad:
return y
@ -1273,7 +1273,7 @@ class SwooshL(torch.nn.Module):
"""
if torch.jit.is_scripting():
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
return SwooshLFunction.apply(x)