Merge branch 'scaled_adam_exp467' into scaled_adam_exp472

Daniel Povey 2022-11-23 14:13:19 +08:00
commit edd4bf5312
2 changed files with 79 additions and 2 deletions

View File

@@ -928,7 +928,6 @@ class DoubleSwishFunction(torch.autograd.Function):
        d = (d * ((ceil - floor) / 255.0) + floor)
        return (y_grad * d)

class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
@@ -939,6 +938,68 @@ class DoubleSwish(torch.nn.Module):
        return DoubleSwishFunction.apply(x)

class TanSwishFunction(torch.autograd.Function):
    """
    tan_swish(x) = tanh(x) * torch.sigmoid(x - 1)

    Entering d/dx(tanh(x) * sigmoid(x-1))
    into Wolfram Alpha, I see that the range of this derivative is
      -0.0498087 <= y <= 0.417894
    let's widen it slightly (as we don't know how these values were rounded):
      -0.0498088 <= y <= 0.417895
    """
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        if not requires_grad:
            return torch.tanh(x) * torch.sigmoid(x - 1.0)
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)
        with torch.cuda.amp.autocast(enabled=False):
            with torch.enable_grad():
                x = x.detach()
                x.requires_grad = True
                y = torch.tanh(x) * torch.sigmoid(x - 1.0)
                y.backward(gradient=torch.ones_like(y))
                grad = x.grad
                floor = -0.0498088
                ceil = 0.417895
                # scale the derivative into [0, 255] and add uniform dither in [0, 1),
                # so that truncation to uint8 below is expectation-preserving.
                d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
                if __name__ == "__main__":
                    # for self-testing only.
                    assert d_scaled.min() >= 0.0
                    assert d_scaled.max() < 256.0
                d_int = d_scaled.to(torch.uint8)
                ctx.save_for_backward(d_int)
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y
    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        d, = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.0498088
        ceil = 0.417895
        d = (d * ((ceil - floor) / 255.0) + floor)
        return (y_grad * d)

class TanSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return tan-swish activation, which is tanh(x) * sigmoid(x - 1)."""
        if torch.jit.is_scripting():
            return x.tanh() * torch.sigmoid(x - 1.0)
        return TanSwishFunction.apply(x)

class ScheduledFloat(torch.nn.Module):
    """
    This object is a torch.nn.Module only because we want it to show up in [top_level module].modules();
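
Aside (not part of the commit): the floor/ceil constants quoted in TanSwishFunction's docstring can be sanity-checked numerically with PyTorch autograd; a minimal sketch:

import torch

# Evaluate d/dx(tanh(x) * sigmoid(x - 1)) on a dense grid and confirm it stays
# inside the bounds quoted in the docstring, [-0.0498088, 0.417895].
x = torch.linspace(-20.0, 20.0, steps=200001, dtype=torch.double, requires_grad=True)
y = torch.tanh(x) * torch.sigmoid(x - 1.0)
(grad,) = torch.autograd.grad(y.sum(), x)
floor, ceil = -0.0498088, 0.417895
assert grad.min().item() >= floor and grad.max().item() <= ceil
print(grad.min().item(), grad.max().item())  # roughly -0.0498087 and 0.417894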
@@ -1147,6 +1208,20 @@ def _test_double_swish_deriv():
    x.requires_grad = True
    y = m(x)

def _test_tan_swish_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = TanSwish()
    # atol = one uint8 quantization step of the stored derivative.
    tol = (0.417895 - (-0.0498088)) / 255.0
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)
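
Note (not from the commit): the gradcheck atol above corresponds to one uint8 quantization step of the stored derivative. A small round-trip sketch of the dithered quantization, under that assumption:

import torch

floor, ceil = -0.0498088, 0.417895
step = (ceil - floor) / 255.0
grad = torch.linspace(floor, ceil, steps=1001)
# forward pass: scale to [0, 255], add uniform dither, truncate to uint8
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
d_int = d_scaled.to(torch.uint8)
# backward pass: dequantize with the same constants
grad_hat = d_int * step + floor
# the reconstruction error never exceeds one quantization step
assert (grad_hat - grad).abs().max().item() <= step + 1e-6
print((grad_hat - grad).abs().max().item(), step)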

def _test_softmax():
@@ -1173,3 +1248,4 @@ if __name__ == "__main__":
    _test_activation_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()
    _test_tan_swish_deriv()

View File

@@ -29,6 +29,7 @@ from scaling import (
    BasicNorm,
    MaxEig,
    DoubleSwish,
    TanSwish,
    ScaledConv1d,
    ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
    Whiten,
@@ -1317,7 +1318,7 @@ class AttentionSqueeze(nn.Module):
            max_factor=0.02,
            min_prob=0.1,
        )
        self.bottleneck_activation = nn.Tanh()  # in bottleneck
        self.bottleneck_activation = TanSwish()  # in bottleneck
        self.activation = Identity()  # for diagnostics
        # the next two balancers are only to stop parameter-magnitude 'drift': we have
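
For context (hypothetical usage, not from the commit): TanSwish is used here as a drop-in replacement for nn.Tanh at the bottleneck, e.g.:

import torch
from scaling import TanSwish  # assumes the scaling.py from this commit is on the path

m = TanSwish()
x = torch.randn(4, 256, requires_grad=True)
y = m(x)            # tanh(x) * sigmoid(x - 1), with the uint8-quantized backward
y.sum().backward()
print(y.shape, x.grad.abs().mean())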