Use Swoosh-R in the Conv and Swoosh-L in the feedforward.

Daniel Povey 2022-12-04 19:18:16 +08:00
parent d214e1c352
commit 7b1f093077
2 changed files with 89 additions and 11 deletions


@@ -1214,7 +1214,7 @@ class TanSwish(torch.nn.Module):
-class SwooshFunction(torch.autograd.Function):
+class SwooshLFunction(torch.autograd.Function):
     """
     swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
@@ -1267,14 +1267,77 @@ class SwooshFunction(torch.autograd.Function):
         return (y_grad * d)


-class Swoosh(torch.nn.Module):
+class SwooshL(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
-        """Return tan-swish activation function which is tanh(x) sigmoid(x-1)
+        """Return Swoosh-L activation.
         """
         if torch.jit.is_scripting():
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return torch.logaddexp(zero, x - 4.0) - 0.1 * x - 0.035
-        return SwooshFunction.apply(x)
+        return SwooshLFunction.apply(x)
+
+
+class SwooshRFunction(torch.autograd.Function):
+    """
+    swoosh(x) = log(1 + exp(x-1)) - 0.08*x - 0.313261687
+
+    derivatives are between -0.08 and 0.92.
+    """
+    @staticmethod
+    def forward(ctx, x: Tensor) -> Tensor:
+        requires_grad = x.requires_grad
+        x_dtype = x.dtype
+        if x.dtype == torch.float16:
+            x = x.to(torch.float32)
+
+        zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            with torch.enable_grad():
+                x = x.detach()
+                x.requires_grad = True
+                y = torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+
+                if not requires_grad:
+                    return y
+                y.backward(gradient=torch.ones_like(y))
+
+                grad = x.grad
+                floor = -0.08
+                ceil = 0.925
+                d_scaled = ((grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad))
+                if __name__ == "__main__":
+                    # for self-testing only.
+                    assert d_scaled.min() >= 0.0
+                    assert d_scaled.max() < 256.0
+
+                d_int = d_scaled.to(torch.uint8)
+                ctx.save_for_backward(d_int)
+                if x_dtype == torch.float16 or torch.is_autocast_enabled():
+                    y = y.to(torch.float16)
+                return y
+
+    @staticmethod
+    def backward(ctx, y_grad: Tensor) -> Tensor:
+        d, = ctx.saved_tensors
+        # the same constants as used in forward pass.
+        floor = -0.08
+        ceil = 0.925
+        d = (d * ((ceil - floor) / 255.0) + floor)
+        return (y_grad * d)
+
+
+class SwooshR(torch.nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        """Return Swoosh-R activation.
+        """
+        if torch.jit.is_scripting():
+            zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
+            return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
+        return SwooshRFunction.apply(x)
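
For reference, the trick in SwooshRFunction above is to save the derivative for the backward pass as a dithered uint8 code instead of keeping float tensors around. The sketch below is not part of the commit; it reproduces the same round-trip outside autograd, using the analytic derivative sigmoid(x - 1) - 0.08 of the Swoosh-R formula and the [floor, ceil] constants from the diff.

    import torch

    # Derivative of log(1 + exp(x - 1)) - 0.08*x is sigmoid(x - 1) - 0.08,
    # which lies in (-0.08, 0.92); the diff quantizes it over [-0.08, 0.925].
    x = torch.linspace(-8.0, 8.0, 1001)
    grad = torch.sigmoid(x - 1.0) - 0.08

    floor, ceil = -0.08, 0.925
    # Forward pass: scale to [0, 255], add uniform dither, keep one byte per element.
    d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(grad)
    d_int = d_scaled.to(torch.uint8)

    # Backward pass: undo the scaling to get an approximate derivative back.
    d = d_int * ((ceil - floor) / 255.0) + floor
    print((d - grad).abs().max())  # about (ceil - floor) / 255, i.e. roughly 0.004

Storing one byte per element rather than the float input is where the memory saving comes from, and the random dither makes the truncation to uint8 unbiased in expectation.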
@@ -1434,10 +1497,23 @@ def _test_tan_swish_deriv():
     x.requires_grad = True
     y = m(x)


-def _test_swoosh_deriv():
+def _test_swooshl_deriv():
     x = torch.randn(10, 12, dtype=torch.double) * 3.0
     x.requires_grad = True
-    m = Swoosh()
+    m = SwooshL()
+
+    tol = (1.0 / 255.0)
+    torch.autograd.gradcheck(m, x, atol=tol)
+
+    # for self-test.
+    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    y = m(x)
+
+
+def _test_swooshr_deriv():
+    x = torch.randn(10, 12, dtype=torch.double) * 3.0
+    x.requires_grad = True
+    m = SwooshR()
     tol = (1.0 / 255.0)
     torch.autograd.gradcheck(m, x, atol=tol)
@@ -1474,4 +1550,5 @@ if __name__ == "__main__":
     _test_basic_norm()
     _test_double_swish_deriv()
     _test_tan_swish_deriv()
-    _test_swoosh_deriv()
+    _test_swooshr_deriv()
+    _test_swooshl_deriv()


@@ -29,7 +29,8 @@ from scaling import (
     BasicNorm,
     MaxEig,
     DoubleSwish,
-    Swoosh,
+    SwooshL,
+    SwooshR,
     TanSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
@@ -1426,7 +1427,7 @@ class FeedforwardModule(nn.Module):
             min_abs=1.0,
             max_abs=5.0,
             min_prob=0.25)
-        self.activation = Swoosh()
+        self.activation = SwooshL()
         self.dropout = nn.Dropout(dropout)
         self.out_proj = LinearWithAuxLoss(feedforward_dim, embed_dim,
                                           initial_scale=0.01,
@@ -1601,11 +1602,11 @@ class ConvolutionModule(nn.Module):
             channels, channel_dim=1,
             min_positive=ScheduledFloat((0.0, 0.1), (8000.0, 0.05)),
             max_positive=1.0,
-            min_abs=0.75,
+            min_abs=ScheduledFloat((0.0, 0.2), (20000.0, 1.0)),
             max_abs=10.0,
         )
-        self.activation = nn.Tanh()
+        self.activation = SwooshR()
         self.whiten = Whiten(num_groups=1,
                              whitening_limit=_whitening_schedule(7.5),
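
The other change in this last hunk replaces the constant min_abs=0.75 with a schedule. Assuming ScheduledFloat interpolates linearly between its (step, value) breakpoints, as its arguments suggest (an assumption about the icefall helper, not something shown in this diff), the constraint ramps from 0.2 at the start of training to 1.0 by step 20000. A hypothetical stand-in, for illustration only:

    # Hypothetical two-point stand-in for ScheduledFloat, assuming piecewise-linear
    # interpolation between (step, value) breakpoints; not the library's implementation.
    def scheduled_value(step, p0, p1):
        (s0, v0), (s1, v1) = p0, p1
        if step <= s0:
            return v0
        if step >= s1:
            return v1
        return v0 + (v1 - v0) * (step - s0) / (s1 - s0)

    for step in (0, 5000, 10000, 20000, 30000):
        print(step, scheduled_value(step, (0.0, 0.2), (20000.0, 1.0)))
    # 0 -> 0.2, 5000 -> 0.4, 10000 -> 0.6, then held at 1.0 from step 20000 on

Read this way, the balancer's min_abs constraint starts weak and tightens as training progresses, alongside the already-scheduled min_positive a few lines above.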