Mirror of https://github.com/k2-fsa/icefall.git
Change DoubleSwish formulation, add alpha*x only for x.abs() > 0.15.
commit 983a690c63
parent 8976e1e43b
@@ -1065,6 +1065,7 @@ class MaxEig(torch.nn.Module):
 class DoubleSwishFunction(torch.autograd.Function):
     """
     double_swish(x) = x * (torch.sigmoid(x-1) + alpha)

     for e.g. alpha=-0.05 (user supplied).
     This is a definition, originally motivated by its close numerical
     similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
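To see what the commit title means in practice, here is a small standalone sketch of the new formulation (not the repo code; the helper name double_swish_ref is just for illustration). With alpha = -0.05 and beta = 0.05 the two extra terms cancel exactly for |x| <= 0.15, so the alpha * x slope only takes effect outside that band:

import torch

def double_swish_ref(x: torch.Tensor) -> torch.Tensor:
    # new formulation from this commit: x * (sigmoid(x-1) + alpha)
    # plus beta * clamp(x, -x_limit, x_limit), with the fixed constants below.
    alpha, beta, x_limit = -0.05, 0.05, 0.15
    return x * (torch.sigmoid(x - 1.0) + alpha) + beta * x.clamp(min=-x_limit, max=x_limit)

x_small = torch.tensor([-0.1, 0.0, 0.1])
x_large = torch.tensor([-1.0, 2.0])
# inside the band, alpha*x and beta*x cancel, so this equals plain x * sigmoid(x-1)
print(torch.allclose(double_swish_ref(x_small), x_small * torch.sigmoid(x_small - 1.0), atol=1e-6))
# outside the band, the extra term is alpha*x + beta*x_limit*sign(x), i.e. alpha only matters there
print(double_swish_ref(x_large) - x_large * torch.sigmoid(x_large - 1.0))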
@@ -1080,26 +1081,36 @@ class DoubleSwishFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
+    def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
         x_dtype = x.dtype
-        ctx.alpha = alpha
         if x.dtype == torch.float16:
             x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
         y = x * s

+        alpha = -0.05
+        beta = 0.05
+        x_limit = 0.15
+
+        # another part of this formula is:
+        # ... + beta * x.clamp(min=-x_limit, max=x_limit)
+        # the deriv of this is
+        # beta * (x.abs() < x_limit).
+
         if requires_grad:
-            deriv = (y * (1 - s) + s)
+            deriv = (y * (1 - s) + s)  # ignores the alpha part.
+            deriv = deriv + (x.abs() < x_limit) * beta

             # notes on derivative of x * sigmoid(x - 1):
             # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
-            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound
+            # min \simeq -0.043638.  Take floor as -0.044 so it's a lower bound
             # max \simeq 1.1990.  Take ceil to be 1.2 so it's an upper bound.
             # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
             # floors), should be expectation-preserving.
-            floor = -0.043637
-            ceil = 1.2
+            floor = -0.044
+            ceil = 1.2 + beta
             d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv))
             if __name__ == "__main__":
                 # for self-testing only.
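The analytic derivative assembled above, (y * (1 - s) + s) plus alpha plus beta * (x.abs() < x_limit), can be sanity-checked against PyTorch autograd on the plain, unquantized expression. A minimal sketch (illustrative only; sample points are chosen away from the clamp boundary at +/-0.15, where the gradient is not defined):

import torch

alpha, beta, x_limit = -0.05, 0.05, 0.15
x = torch.tensor([-2.0, -0.1, 0.05, 0.5, 3.0], requires_grad=True)

s = torch.sigmoid(x - 1.0)
y = x * s
out = y + alpha * x + beta * x.clamp(min=-x_limit, max=x_limit)
out.sum().backward()  # exact gradient of the full expression via autograd

# hand-derived gradient from the diff: the (y*(1-s)+s) part, the constant alpha,
# and beta on the clamp's linear region (indicator of |x| < x_limit)
manual = y * (1 - s) + s + alpha + beta * (x.abs() < x_limit)
print(torch.allclose(x.grad, manual.detach(), atol=1e-6))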
@@ -1107,7 +1118,7 @@ class DoubleSwishFunction(torch.autograd.Function):
                 assert d_scaled.max() < 256.0
             d_int = d_scaled.to(torch.uint8)
             ctx.save_for_backward(d_int)
-        y = y + alpha * x
+        y = y + alpha * x + beta * x.clamp(min=-x_limit, max=x_limit)
         if x.dtype == torch.float16 or torch.is_autocast_enabled():
             y = y.to(torch.float16)
         return y
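The derivative is stored as uint8 to keep the memory cost of the backward pass low; adding uniform [0, 1) noise before the cast (which floors) makes the rounding unbiased, provided the dequantization in backward uses exactly the same floor and ceil constants as the forward pass. A quick round-trip sketch with made-up derivative values (not the repo code):

import torch

torch.manual_seed(0)
beta = 0.05
floor, ceil = -0.044, 1.2 + beta                       # bounds used when quantizing
deriv = torch.rand(100000) * (ceil - floor) + floor    # fake derivative values in [floor, ceil)

# quantize: scale to [0, 255), add U[0, 1) noise, cast to uint8 (casting floors)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
d_int = d_scaled.to(torch.uint8)

# dequantize with the same constants (they must match for the estimate to be unbiased)
d_back = d_int * ((ceil - floor) / 255.0) + floor

print((d_back - deriv).mean())        # close to 0: stochastic rounding preserves the mean
print((d_back - deriv).abs().max())   # per-element error stays below one quantization step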
@@ -1115,29 +1126,27 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, y_grad: Tensor) -> Tensor:
         d, = ctx.saved_tensors
-        alpha = ctx.alpha
         # the same constants as used in forward pass.
+        alpha = -0.05
+        beta = 0.05
         floor = -0.043637
-        ceil = 1.2
+        ceil = 1.2 + beta

         d = (d * ((ceil - floor) / 255.0) + floor)
-        return (y_grad * (d + alpha)), None
+        return (y_grad * (d + alpha))


 class DoubleSwish(torch.nn.Module):
-    def __init__(self,
-                 alpha: float = -0.05):
+    def __init__(self):
         super().__init__()
-        self.alpha = alpha
-
-    def extra_repr(self) -> str:
-        return 'alpha={}'.format(self.alpha)

     def forward(self, x: Tensor) -> Tensor:
         """Return double-swish activation function which is an approximation to Swish(Swish(x)),
         that we approximate closely with x * sigmoid(x-1).
         """
         if torch.jit.is_scripting():
-            return x * (torch.sigmoid(x - 1.0) + self.alpha)
-        return DoubleSwishFunction.apply(x, self.alpha)
+            return x * (torch.sigmoid(x - 1.0) - 0.05) + 0.05 * x.clamp(min=-0.15, max=0.15)
+        return DoubleSwishFunction.apply(x)


 class TanSwishFunction(torch.autograd.Function):
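Putting the pieces together, the following is a condensed, self-contained sketch of the autograd Function in this diff, useful for measuring how far the uint8-quantized gradient is from the exact one outside the repo. It is not the repo's class: the name _DoubleSwishSketch is made up, the requires_grad check and float16/autocast handling are omitted, and the same floor/ceil are used in both directions here, following the "same constants as used in forward pass" note.

import torch
from torch import Tensor

class _DoubleSwishSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        alpha, beta, x_limit = -0.05, 0.05, 0.15
        s = torch.sigmoid(x - 1.0)
        y = x * s
        # derivative of x*sigmoid(x-1) plus the clamp term; alpha is added in backward
        deriv = y * (1 - s) + s + (x.abs() < x_limit) * beta
        floor, ceil = -0.044, 1.2 + beta
        d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
        ctx.save_for_backward(d_scaled.to(torch.uint8))
        return y + alpha * x + beta * x.clamp(min=-x_limit, max=x_limit)

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        alpha, beta = -0.05, 0.05
        floor, ceil = -0.044, 1.2 + beta
        d = d * ((ceil - floor) / 255.0) + floor   # dequantize with matching constants
        return y_grad * (d + alpha)

x = torch.randn(10000, requires_grad=True)
_DoubleSwishSketch.apply(x).sum().backward()
approx = x.grad

# exact gradient of the same expression via plain autograd, for comparison
x2 = x.detach().clone().requires_grad_(True)
(x2 * (torch.sigmoid(x2 - 1.0) - 0.05) + 0.05 * x2.clamp(min=-0.15, max=0.15)).sum().backward()
print((approx - x2.grad).abs().max())

The maximum difference should be on the order of one quantization step, roughly (ceil - floor) / 255, i.e. about 0.005.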