Mirror of https://github.com/k2-fsa/icefall.git
Introduce alpha for DoubleSwish, set it to -0.05.
commit d682ecc246 (parent 2969eb5467)
@@ -1064,7 +1064,8 @@ class MaxEig(torch.nn.Module):

 class DoubleSwishFunction(torch.autograd.Function):
     """
-    double_swish(x) = x * torch.sigmoid(x-1)
+    double_swish(x) = x * (torch.sigmoid(x-1) + alpha)
+    e.g. alpha = -0.05 (user supplied).
     This is a definition, originally motivated by its close numerical
     similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

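The redefinition above is easy to sanity-check. A standalone sketch (not part of the patch): compare x * torch.sigmoid(x - 1) against swish(swish(x)), and confirm that the alpha variant differs from the original by exactly alpha * x.

# Standalone sanity check (not part of the patch): how close is
# x * sigmoid(x - 1) to swish(swish(x)), and what does alpha change?
import torch

x = torch.linspace(-6.0, 6.0, steps=1001)

def swish(t: torch.Tensor) -> torch.Tensor:
    return t * torch.sigmoid(t)  # swish(x) = x * sigmoid(x)

reference = swish(swish(x))                       # the function being approximated
double_swish = x * torch.sigmoid(x - 1.0)         # the closed-form stand-in
with_alpha = x * (torch.sigmoid(x - 1.0) - 0.05)  # this commit, alpha = -0.05

print((reference - double_swish).abs().max())     # small over this range
print(torch.allclose(with_alpha, double_swish - 0.05 * x))  # True: shift is alpha * x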
@@ -1079,9 +1080,10 @@ class DoubleSwishFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, x: Tensor) -> Tensor:
+    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
         requires_grad = x.requires_grad
         x_dtype = x.dtype
+        ctx.alpha = alpha
         if x.dtype == torch.float16:
             x = x.to(torch.float32)

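Since alpha is a plain Python float rather than a Tensor, forward() stashes it on ctx instead of using save_for_backward. The code elided between this hunk and the next presumably computes s = sigmoid(x - 1), y = x * s and the derivative that gets quantized; a hedged check of that derivative identity, assuming exactly that formula:

# Hedged check: for y = x * sigmoid(x - 1), the derivative is
# s + y * (1 - s) with s = sigmoid(x - 1); adding alpha * x to the
# output later shifts this derivative by the constant alpha.
import torch

x = torch.randn(16, dtype=torch.float64, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s
analytic = s + y * (1.0 - s)                 # closed-form derivative
(autograd_d,) = torch.autograd.grad(y.sum(), x)
print(torch.allclose(autograd_d, analytic))  # True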
@@ -1105,6 +1107,7 @@ class DoubleSwishFunction(torch.autograd.Function):
             assert d_scaled.max() < 256.0
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
+        y = y + alpha * x
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
             y = y.to(torch.float16)
         return y
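Note the placement of the new y = y + alpha * x: the offset is applied after the unshifted derivative has been quantized, so the stored uint8 table is unchanged and alpha is re-added analytically in backward. The construction of d_scaled is not shown in this hunk; the sketch below is one plausible reconstruction, using only the floor/ceil constants that appear in the backward pass below.

# Sketch of the uint8 round trip (a reconstruction, not the patch itself).
# floor/ceil are the constants named in the backward pass.
import torch

floor, ceil = -0.043637, 1.2
deriv = torch.empty(1000).uniform_(floor, ceil)        # assumed derivative range
d_scaled = (deriv - floor) * (255.0 / (ceil - floor))  # map to [0, 255]
assert d_scaled.min() >= 0.0 and d_scaled.max() < 256.0
d_int = d_scaled.to(torch.uint8)                       # 1 byte per element
restored = d_int * ((ceil - floor) / 255.0) + floor    # backward's dequantization
print((restored - deriv).abs().max())                  # <= (ceil - floor) / 255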
@@ -1112,20 +1115,29 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
+        alpha = ctx.alpha
         # the same constants as used in forward pass.
         floor = -0.043637
         ceil = 1.2
         d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * d)
+        return (y_grad * (d + alpha)), None


 class DoubleSwish(torch.nn.Module):
+    def __init__(self,
+                 alpha: float = -0.05):
+        super().__init__()
+        self.alpha = alpha
+
+    def extra_repr(self) -> str:
+        return 'alpha={}'.format(self.alpha)
+
     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
         if torch.jit.is_scripting():
-            return x * torch.sigmoid(x - 1.0)
-        return DoubleSwishFunction.apply(x)
+            return x * (torch.sigmoid(x - 1.0) + self.alpha)
+        return DoubleSwishFunction.apply(x, self.alpha)


 class TanSwishFunction(torch.autograd.Function):
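A usage sketch, assuming the patched DoubleSwish and DoubleSwishFunction from this file are in scope (the import path is not shown in the diff): forward matches the direct formula exactly, while the quantized backward should agree with autograd on that formula to within roughly one uint8 step.

# Hypothetical usage of the patched classes (names from the diff above).
import torch
# Import path is an assumption; these classes live in icefall's scaling.py:
# from scaling import DoubleSwish, DoubleSwishFunction

alpha = -0.05
x = torch.randn(100, requires_grad=True)

y = DoubleSwishFunction.apply(x, alpha)        # custom forward/backward
y.sum().backward()
quantized_grad = x.grad.clone()

x.grad = None
(x * (torch.sigmoid(x - 1.0) + alpha)).sum().backward()  # exact autograd
print((quantized_grad - x.grad).abs().max())   # ~ one uint8 step, about 5e-3

m = DoubleSwish()  # defaults to alpha = -0.05
print(m)           # "DoubleSwish(alpha=-0.05)", via extra_repr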