mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Change x coeff from -0.1 to -0.08, as in 608.
This commit is contained in:
parent
7b1f093077
commit
c57eaf7979
@ -1216,7 +1216,7 @@ class TanSwish(torch.nn.Module):
|
|||||||
|
|
||||||
class SwooshLFunction(torch.autograd.Function):
|
class SwooshLFunction(torch.autograd.Function):
|
||||||
"""
|
"""
|
||||||
swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
|
swoosh(x) = log(1 + exp(x-4)) - 0.08*x - 0.035
|
||||||
|
|
||||||
derivatives are between -0.08 and 0.92.
|
derivatives are between -0.08 and 0.92.
|
||||||
"""
|
"""
|
||||||
@ -1235,7 +1235,7 @@ class SwooshLFunction(torch.autograd.Function):
|
|||||||
with torch.enable_grad():
|
with torch.enable_grad():
|
||||||
x = x.detach()
|
x = x.detach()
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
y = torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
|
y = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||||
|
|
||||||
if not requires_grad:
|
if not requires_grad:
|
||||||
return y
|
return y
|
||||||
@ -1273,7 +1273,7 @@ class SwooshL(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
if torch.jit.is_scripting():
|
if torch.jit.is_scripting():
|
||||||
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||||
return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
|
return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
|
||||||
return SwooshLFunction.apply(x)
|
return SwooshLFunction.apply(x)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user