First version of swoosh

Daniel Povey 2022-12-02 16:34:53 +08:00
parent d260b54177
commit ec10573edc


@@ -1212,6 +1212,72 @@ class TanSwish(torch.nn.Module):
        return TanSwishFunction.apply(x)
class SwooshFunction(torch.autograd.Function):
    """
    swoosh(x) = log(1 + exp(x-4)) - 0.055*x - 0.15

    Its derivative is sigmoid(x-4) - 0.055, so derivatives are strictly
    between -0.055 and 1-0.055 = 0.945.
    """
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        # log(1 + exp(t)) == logaddexp(0, t)
        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.logaddexp(zero, x - 4) - 0.055 * x - 0.15

                if not requires_grad:
                    return y
                y.backward(gradient=torch.ones_like(y))

                grad = x.grad
                floor = -0.055
                ceil = 0.946  # real ceil would be 0.945; leave extra room for roundoff.

                # map the derivative from [floor, ceil] onto [0, 255] and add
                # uniform noise in [0, 1) so that truncation to uint8 dithers
                # instead of always rounding down, keeping it unbiased.
                d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
                if __name__ == "__main__":
                    # for self-testing only.
                    assert d_scaled.min() >= 0.0
                    assert d_scaled.max() < 256.0

                d_int = d_scaled.to(torch.uint8)
                ctx.save_for_backward(d_int)
                # x may have been converted to float32 above, so test the
                # original dtype x_dtype, not x.dtype.
                if x_dtype == torch.float16 or torch.is_autocast_enabled():
                    y = y.to(torch.float16)
                return y
    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        d, = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.055
        ceil = 0.946
        # undo the uint8 quantization: map [0, 255] back onto [floor, ceil].
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
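
A quick sanity check (an illustration added here, not part of the commit): the gradient that backward() reconstructs should match the analytic derivative sigmoid(x - 4) - 0.055 to within one uint8 quantization step, (0.946 - (-0.055)) / 255 ≈ 0.004, plus at most one more step of dither noise. A minimal sketch, assuming torch and the class above are in scope:

# illustrative check only; tolerance = two quantization steps
x = torch.randn(100, requires_grad=True)
y = SwooshFunction.apply(x)
y.backward(gradient=torch.ones_like(y))
true_grad = torch.sigmoid(x.detach() - 4) - 0.055
assert (x.grad - true_grad).abs().max() <= 2 * (0.946 + 0.055) / 255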
class Swoosh(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return Swoosh activation, i.e. log(1 + exp(x-4)) - 0.055*x - 0.15.
        """
        if torch.jit.is_scripting():
            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
            return torch.logaddexp(zero, x - 4) - 0.055 * x - 0.15
        return SwooshFunction.apply(x)
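
For context (an illustration, not part of the commit): the point of the custom autograd Function is memory. It saves one uint8 per element for backward, rather than the float32 tensors a plain eager implementation would keep alive, roughly a 4x saving on activation memory. A minimal sketch:

m = Swoosh()
x = torch.randn(16, 512, 512, requires_grad=True)  # ~16.8 MB of float32
y = m(x)  # saves ~4.2 MB of uint8 derivatives for backward, not float32
y.sum().backward()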
def _test_max_eig():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
@@ -1368,6 +1434,19 @@ def _test_tan_swish_deriv():
    x.requires_grad = True
    y = m(x)
def _test_swoosh_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = Swoosh()

    # the uint8 quantization of the derivative limits its accuracy to
    # about one part in 255, so relax gradcheck's tolerance accordingly.
    tol = (1.0 / 255.0)
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test: a larger tensor gives better coverage of the range
    # asserts inside SwooshFunction.forward.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)
def _test_softmax():
@@ -1395,3 +1474,4 @@ if __name__ == "__main__":
    _test_basic_norm()
    _test_double_swish_deriv()
    _test_tan_swish_deriv()
    _test_swoosh_deriv()
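
An illustrative usage sketch (hypothetical, not from this commit): the module drops in wherever a pointwise activation is used, e.g. in a feed-forward block:

ff = torch.nn.Sequential(
    torch.nn.Linear(256, 1024),
    Swoosh(),
    torch.nn.Linear(1024, 256),
)
out = ff(torch.randn(8, 256))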